Example #1
class SoftActorCritic(RL_algorithm):

    def __init__(self, config, env, replay, networks):
        """ Bascally a wrapper class for SAC from rlkit.

        Args:
            config: Configuration dictionary
            env: Environment
            replay: Replay buffer
            networks: dict containing two sub-dicts, 'individual' and 'population'
                which contain the networks.

        """
        super().__init__(config, env, replay, networks)

        self._variant_pop = config['rl_algorithm_config']['algo_params_pop']
        self._variant_spec = config['rl_algorithm_config']['algo_params']

        self._ind_qf1 = networks['individual']['qf1']
        self._ind_qf2 = networks['individual']['qf2']
        self._ind_qf1_target = networks['individual']['qf1_target']
        self._ind_qf2_target = networks['individual']['qf2_target']
        self._ind_policy = networks['individual']['policy']

        self._pop_qf1 = networks['population']['qf1']
        self._pop_qf2 = networks['population']['qf2']
        self._pop_qf1_target = networks['population']['qf1_target']
        self._pop_qf2_target = networks['population']['qf2_target']
        self._pop_policy = networks['population']['policy']

        self._batch_size = config['rl_algorithm_config']['batch_size']
        self._nmbr_indiv_updates = config['rl_algorithm_config']['indiv_updates']
        self._nmbr_pop_updates = config['rl_algorithm_config']['pop_updates']

        self._algorithm_ind = SoftActorCritic_rlkit(
            env=self._env,
            policy=self._ind_policy,
            qf1=self._ind_qf1,
            qf2=self._ind_qf2,
            target_qf1=self._ind_qf1_target,
            target_qf2=self._ind_qf2_target,
            use_automatic_entropy_tuning=False,
            **self._variant_spec
        )

        self._algorithm_pop = SoftActorCritic_rlkit(
            env=self._env,
            policy=self._pop_policy,
            qf1=self._pop_qf1,
            qf2=self._pop_qf2,
            target_qf1=self._pop_qf1_target,
            target_qf2=self._pop_qf2_target,
            use_automatic_entropy_tuning=False,
            **self._variant_pop
        )

        # self._algorithm_ind.to(ptu.device)
        # self._algorithm_pop.to(ptu.device)

    def episode_init(self):
        """ Initializations to be done before the first episode.

        In this case it creates a fresh instance of SAC for the individual
        networks and, if configured, copies the population network parameters
        into the individual ones.
        """
        self._algorithm_ind = SoftActorCritic_rlkit(
            env=self._env,
            policy=self._ind_policy,
            qf1=self._ind_qf1,
            qf2=self._ind_qf2,
            target_qf1=self._ind_qf1_target,
            target_qf2=self._ind_qf2_target,
            use_automatic_entropy_tuning=False,
            # alt_alpha = self._alt_alpha,
            **self._variant_spec
        )
        if self._config['rl_algorithm_config']['copy_from_gobal']:
            utils.copy_pop_to_ind(networks_pop=self._networks['population'], networks_ind=self._networks['individual'])
        # We only have to do this because the version of rlkit we use
        # internally creates a target network
        # vf_dict = self._algorithm_pop.target_vf.state_dict()
        # self._algorithm_ind.target_vf.load_state_dict(vf_dict)
        # self._algorithm_ind.target_vf.eval()
        # self._algorithm_ind.to(ptu.device)

    def single_train_step(self, train_ind=True, train_pop=False):
        """ A single trianing step.

        Args:
            train_ind: Boolean. If true the individual networks will be trained.
            train_pop: Boolean. If true the population networks will be trained.
        """
        if train_ind:
            # Get only samples from the species buffer
            self._replay.set_mode('species')
            # self._algorithm_ind.num_updates_per_train_call = self._variant_spec['num_updates_per_epoch']
            # self._algorithm_ind._try_to_train()
            for _ in range(self._nmbr_indiv_updates):
                batch = self._replay.random_batch(self._batch_size)
                self._algorithm_ind.train(batch)

        if train_pop:
            # Get only samples from the population buffer
            self._replay.set_mode('population')
            # self._algorithm_pop.num_updates_per_train_call = self._variant_pop['num_updates_per_epoch']
            # self._algorithm_pop._try_to_train()
            for _ in range(self._nmbr_pop_updates):
                batch = self._replay.random_batch(self._batch_size)
                self._algorithm_pop.train(batch)

    @staticmethod
    def create_networks(env, config):
        """ Creates all networks necessary for SAC.

        These networks have to be created before this class is instantiated
        and are passed to its constructor.

        Args:
            env: Environment
            config: A configuration dictionary describing the population and
                individual networks.

        Returns:
            A dictionary which contains the networks.
        """
        network_dict = {
            'individual' : SoftActorCritic._create_networks(env=env, config=config),
            'population' : SoftActorCritic._create_networks(env=env, config=config),
        }
        return network_dict

    @staticmethod
    def _create_networks(env, config):
        """ Creates all networks necessary for SAC.

        These networks have to be created before this class is instantiated
        and are passed to its constructor.

        TODO: Maybe this should be reworked one day...

        Args:
            env: Environment
            config: A configuration dictionary.

        Returns:
            A dictionary which contains the networks.
        """
        obs_dim = int(np.prod(env.observation_space.shape))
        action_dim = int(np.prod(env.action_space.shape))
        net_size = config['rl_algorithm_config']['net_size']
        hidden_sizes = [net_size] * config['rl_algorithm_config']['network_depth']
        # hidden_sizes = [net_size, net_size, net_size]
        qf1 = FlattenMlp(
            hidden_sizes=hidden_sizes,
            input_size=obs_dim + action_dim,
            output_size=1,
        ).to(device=ptu.device)
        qf2 = FlattenMlp(
            hidden_sizes=hidden_sizes,
            input_size=obs_dim + action_dim,
            output_size=1,
        ).to(device=ptu.device)
        qf1_target = FlattenMlp(
            hidden_sizes=hidden_sizes,
            input_size=obs_dim + action_dim,
            output_size=1,
        ).to(device=ptu.device)
        qf2_target = FlattenMlp(
            hidden_sizes=hidden_sizes,
            input_size=obs_dim + action_dim,
            output_size=1,
        ).to(device=ptu.device)
        policy = TanhGaussianPolicy(
            hidden_sizes=hidden_sizes,
            obs_dim=obs_dim,
            action_dim=action_dim,
        ).to(device=ptu.device)

        clip_value = 1.0
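        # Clamp gradients element-wise during backprop to stabilize the Q-function and policy updates.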
        for p in qf1.parameters():
            p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))
        for p in qf2.parameters():
            p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))
        for p in policy.parameters():
            p.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))

        return {'qf1' : qf1, 'qf2' : qf2, 'qf1_target' : qf1_target, 'qf2_target' : qf2_target, 'policy' : policy}

    @staticmethod
    def get_q_network(networks):
        """ Returns the q network from a dict of networks.

        This method extracts the q-network from the dictonary of networks
        created by the function create_networks.

        Args:
            networks: Dict containing the networks.

        Returns:
            The Q-network as a torch object.
        """
        return networks['qf1']

    @staticmethod
    def get_policy_network(networks):
        """ Returns the policy network from a dict of networks.

        This method extracts the policy network from the dictionary of networks
        created by the function create_networks.

        Args:
            networks: Dict containing the networks.

        Returns:
            The policy network as a torch object.
        """
        return networks['policy']
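
A minimal usage sketch for the wrapper above (hedged: the config layout, the environment, and the replay buffer are assumptions; 'replay' stands in for a project-specific buffer exposing set_mode() and random_batch()):

import gym

env = gym.make('Pendulum-v1')  # any continuous-control environment
config = {
    'rl_algorithm_config': {
        'net_size': 256,
        'network_depth': 2,
        'batch_size': 256,
        'indiv_updates': 1,
        'pop_updates': 1,
        'copy_from_gobal': True,   # key spelling follows the class above
        'algo_params': {},         # forwarded to the individual rlkit SAC instance
        'algo_params_pop': {},     # forwarded to the population rlkit SAC instance
    },
}
networks = SoftActorCritic.create_networks(env=env, config=config)
sac = SoftActorCritic(config=config, env=env, replay=replay, networks=networks)

sac.episode_init()
sac.single_train_step(train_ind=True, train_pop=False)
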
Example #2
class SACDeadTrainer(TorchTrainer):
    def __init__(
        self,
        env,
        qf1,
        qf2,
        target_qf1,
        target_qf2,
        qf_dead,
        global_policy,
        discount=0.99,
        reward_scale=1.0,
        policy_lr=1e-3,
        qf_lr=1e-3,
        optimizer_class=RAdam,  # optim.Adam
        soft_target_tau=1e-2,
        target_update_period=1,
        plotter=None,
        render_eval_paths=False,
        use_automatic_entropy_tuning=True,
        target_entropy=None,
    ):
        super().__init__()
        sac_trainer_params = {
            'discount': discount,
            'reward_scale': reward_scale,
            'policy_lr': policy_lr,
            'qf_lr': qf_lr,
            'optimizer_class': optimizer_class,
            'soft_target_tau': soft_target_tau,
            'target_update_period': target_update_period,
            'plotter': plotter,
            'render_eval_paths': render_eval_paths,
            'use_automatic_entropy_tuning': use_automatic_entropy_tuning,
            'target_entropy': target_entropy
        }
        dead_trainer_params = {
            'pass_per_iteration': 1,
        }
        self.sacTrainer = SACTrainer(env=env,
                                     policy=global_policy.tanhGaussianPolicy,
                                     qf1=qf1,
                                     qf2=qf2,
                                     target_qf1=target_qf1,
                                     target_qf2=target_qf2,
                                     **sac_trainer_params)

        self.deadTrainer = DeadTrainer(
            policy_dead=global_policy.deadPredictionPolicy,
            qf_dead=qf_dead,
            **dead_trainer_params)

        ###################################

        self.env = env
        self.policy = global_policy

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

    def train(self, batch):
        """
        :param batch: dict with three sub-batches: 'normal', 'safe' and 'dead'.
            The 'normal' part is used for the regular SAC update, while the
            'safe' and 'dead' parts are used to train the dead-state predictor.
        :return: None
        """
        batch_normal = batch['normal']
        #batch_dead = {key: np.concatenate((batch['dead'][key], batch['safe'][key])) for key in batch['dead']}

        self.sacTrainer.train(batch_normal)

        self.deadTrainer.train(np_batch=batch['safe'],
                               np_batch_dead=batch['dead'])

        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            self.eval_statistics = OrderedDict({
                **self.sacTrainer.eval_statistics,
                **self.deadTrainer.eval_statistics
            })

        self._n_train_steps_total += 1

    def train_from_torch(self, batch):
        raise NotImplementedError('Blank method. Should not be called.')

    def get_diagnostics(self):
        return self.eval_statistics

    def end_epoch(self, epoch):
        self.sacTrainer.end_epoch(epoch)
        self.deadTrainer.end_epoch(epoch)
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        return self.sacTrainer.networks + self.deadTrainer.networks

    @property
    def optimizers(self) -> Iterable[Optimizer]:
        return self.sacTrainer.optimizers + self.deadTrainer.optimizers

    def get_snapshot(self):
        d1 = self.sacTrainer.get_snapshot()
        d2 = self.deadTrainer.get_snapshot()
        return {**d1, **d2}
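
A rough usage sketch for SACDeadTrainer (hedged: the environment, the network objects, and the batch layout are assumptions; only the 'normal', 'safe' and 'dead' keys come from train() above):

import numpy as np

def fake_batch(size, obs_dim, act_dim):
    # Random transitions in the usual rlkit numpy layout (assumed for all three sub-batches).
    return {
        'observations': np.random.randn(size, obs_dim).astype(np.float32),
        'actions': np.random.randn(size, act_dim).astype(np.float32),
        'rewards': np.random.randn(size, 1).astype(np.float32),
        'next_observations': np.random.randn(size, obs_dim).astype(np.float32),
        'terminals': np.zeros((size, 1), dtype=np.float32),
    }

trainer = SACDeadTrainer(
    env=env,                          # gym-style environment
    qf1=qf1, qf2=qf2,
    target_qf1=target_qf1, target_qf2=target_qf2,
    qf_dead=qf_dead,                  # network consumed by the DeadTrainer
    global_policy=global_policy,      # exposes .tanhGaussianPolicy and .deadPredictionPolicy
)

batch = {
    'normal': fake_batch(256, obs_dim, act_dim),
    'safe': fake_batch(256, obs_dim, act_dim),
    'dead': fake_batch(256, obs_dim, act_dim),
}
trainer.train(batch)
print(trainer.get_diagnostics())
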
Example #3
class SAC(Agent):
    def __init__(self, env, eval_env, mem, nets, train_step_params):
        super().__init__(env, eval_env, mem, nets, train_step_params)
        self._mem = mem

        self._env = env
        self._eval_env = eval_env

        self._policy_net = nets['policy_net']
        self._q1_net = nets['q1_net']
        self._q2_net = nets['q2_net']
        self._target_q1_net = nets['target_q1_net']
        self._target_q2_net = nets['target_q2_net']

        self._train_step_params = train_step_params

        self._alg = SACTrainer(env=self._env,
                               policy=self._policy_net,
                               qf1=self._q1_net,
                               qf2=self._q2_net,
                               target_qf1=self._target_q1_net,
                               target_qf2=self._target_q2_net,
                               **train_step_params)

    def _train_step(self, n_trains_per_step, batch_size):
        if self._mem.num_steps_can_sample() < batch_size:
            return

        for _ in range(n_trains_per_step):
            batch = self._mem.random_batch(batch_size)
            self._alg.train(batch)

    def learn(self, results_path, seed_mem=True, **kwargs):
        replay_sample_ratio = kwargs['replay_sample_ratio']
        n_epochs = kwargs['n_epochs']
        episode_len = kwargs['episode_len']
        eval_episode_len = kwargs['eval_episode_len']
        start_steps = kwargs['start_steps']
        n_trains_per_step = kwargs['n_trains_per_step']
        eval_interval = kwargs['eval_interval']
        batch_size = kwargs['train_batch_size']
        checkpoint_interval = kwargs['checkpoint_interval']
        log_interval = kwargs['log_interval']
        artifact_path = results_path + 'artifact'

        # TODO: log summary of current config

        if seed_mem:
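            # Warm up the replay buffer with uniformly random actions before training starts.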
            s = self._env.reset()
            for k in range(start_steps):
                a = self._env.action_space.sample()
                ns, r, done, _ = self._env.step(a)

                self._mem.add_sample(observation=s,
                                     action=a,
                                     reward=r,
                                     next_observation=ns,
                                     terminal=1 if done else 0,
                                     env_info=dict())
                if done:
                    s = self._env.reset()
                else:
                    s = ns

        # TODO: use RAY Sampler for parallel simulation sampling
        train_rt = Agent.results_tracker(id='train_performance')

        s = self._env.reset()
        # TODO: track and log policy loss
        for _ in range(n_epochs):
            for i in range(1, episode_len + 1):
                a = self.pred(s)
                ns, r, done, _ = self._env.step(a)

                self.life_tracker['total_n_train_steps'] += 1
                train_rt['train_interval_timesteps'] += 1
                train_rt['train_rewards'].append(r)

                self._mem.add_sample(observation=s,
                                     action=a,
                                     reward=r,
                                     next_observation=ns,
                                     terminal=1.0 if done else 0.0,
                                     env_info=dict())

                if i % replay_sample_ratio == 0:
                    self._train_step(n_trains_per_step, batch_size)
                    self.life_tracker['total_n_train_batches'] += 1

                if i % checkpoint_interval == 0:
                    self.save(artifact_path)

                if done:
                    s = self._env.reset()

                    self.life_tracker['total_n_train_episodes'] += 1
                    train_rt['n_train_episodes_since_last_log'] += 1

                    n_episodes = self.life_tracker['total_n_train_episodes']
                    if n_episodes % log_interval == 0:
                        if n_episodes % eval_interval == 0:
                            self.life_tracker['total_n_evals'] += 1
                            train_rt['eval_interval_timesteps'] = eval_episode_len

                            s = self._eval_env.reset()
                            for _ in range(eval_episode_len):
                                a = self.pred(s)

                                ns, r, done, _ = self._eval_env.step(a)
                                train_rt['eval_rewards'].append(r)

                                if done:
                                    s = self._eval_env.reset()
                                else:
                                    s = ns

                        self.log_performance(train_rt)
                        train_rt = Agent.results_tracker(id='train_performance')
                else:
                    s = ns

    def pred(self, state, deterministic=False):
        state = torch.from_numpy(state).float().to(torch_util.device)
        with torch.no_grad():
            return self._policy_net(
                state, deterministic=deterministic)[0].cpu().detach().numpy()

    def save(self, filepath):
        if not os.path.isdir(filepath):
            os.makedirs(filepath)

        nets = [
            self._policy_net, self._q1_net, self._q2_net, self._target_q1_net,
            self._target_q2_net
        ]
        net_fps = [
            'policy-net.pt', 'q1-net.pt', 'q2-net.pt', 'target-q1-net.pt',
            'target-q2-net.pt'
        ]
        for i, fn in enumerate(net_fps):
            torch.save(nets[i], f'{filepath}{os.sep}{fn}')

        comps = [self._train_step_params]
        comp_fps = ['train-step-params.pkl']
        for i, fn in enumerate(comp_fps):
            with open(f'{filepath}{os.sep}{fn}', 'wb') as f:
                pickle.dump(comps[i], f)
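
A usage sketch for this agent (hedged: the environments, replay buffer, networks, and all hyperparameter values are assumptions; the keyword names match what learn() reads from kwargs):

agent = SAC(env=env,
            eval_env=eval_env,
            mem=mem,  # buffer with add_sample / random_batch / num_steps_can_sample
            nets={'policy_net': policy_net,
                  'q1_net': q1_net,
                  'q2_net': q2_net,
                  'target_q1_net': target_q1_net,
                  'target_q2_net': target_q2_net},
            train_step_params={'discount': 0.99, 'soft_target_tau': 5e-3})

agent.learn(results_path='./results/',
            seed_mem=True,
            replay_sample_ratio=1,
            n_epochs=100,
            episode_len=1000,
            eval_episode_len=1000,
            start_steps=10000,
            n_trains_per_step=1,
            eval_interval=10,
            train_batch_size=256,
            checkpoint_interval=1000,
            log_interval=10)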