Example #1
File: mbpo.py  Project: anyboby/mbpo
    def __init__(
        self,
        training_environment,
        evaluation_environment,
        policy,
        Qs,
        pool,
        static_fns,
        plotter=None,
        tf_summaries=False,
        lr=3e-4,
        reward_scale=1.0,
        target_entropy='auto',
        discount=0.99,
        tau=5e-3,
        target_update_interval=1,
        action_prior='uniform',
        reparameterize=False,
        store_extra_policy_info=False,
        deterministic=False,
        model_train_freq=250,
        num_networks=7,
        num_elites=5,
        model_retain_epochs=20,
        load_model_dir=None,
        rollout_batch_size=100e3,
        real_ratio=0.1,
        rollout_schedule=[20, 100, 1, 1],
        hidden_dim=200,
        max_model_t=None,
        **kwargs,
    ):
        """
        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy: A policy function approximator.
            initial_exploration_policy ('Policy'): A policy used for initial
                exploration; it is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
        """

        super(MBPO, self).__init__(**kwargs)

        obs_dim = np.prod(training_environment.observation_space.shape)
        act_dim = np.prod(training_environment.action_space.shape)

        self._model_params = dict(obs_dim=obs_dim,
                                  act_dim=act_dim,
                                  hidden_dim=hidden_dim,
                                  num_networks=num_networks,
                                  num_elites=num_elites)
        if load_model_dir is not None:
            self._model_params['load_model'] = True
            self._model_params['model_dir'] = load_model_dir
        self._model = construct_model(**self._model_params)

        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model, self._static_fns)

        self._rollout_schedule = rollout_schedule
        self._max_model_t = max_model_t

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
        print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        ### @anyboby: use a fixed pool size; reallocating causes a memory leak
        obs_space = self._pool._observation_space
        act_space = self._pool._action_space
        rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
        model_steps_per_epoch = int(self._rollout_schedule[-1] *
                                    rollouts_per_epoch)
        mpool_size = self._model_retain_epochs * model_steps_per_epoch

        self._model_pool = SimpleReplayPool(obs_space, act_space, mpool_size)

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        self._build()
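
The fixed model-pool size above is fully determined by the rollout hyperparameters. Below is a minimal sketch of the same arithmetic with the defaults from this constructor, assuming an epoch length of 1000 environment steps (the actual epoch length is supplied by the base algorithm through **kwargs and is not visible in the snippet):

# Sketch of the model-pool sizing used in Example #1.
# Assumption: epoch_length = 1000 (provided by the base class in the real code).
epoch_length = 1000
rollout_batch_size = int(100e3)      # default above
model_train_freq = 250               # default above
rollout_schedule = [20, 100, 1, 1]   # last entry = maximum rollout length
model_retain_epochs = 20

rollouts_per_epoch = rollout_batch_size * epoch_length / model_train_freq   # 400,000
model_steps_per_epoch = int(rollout_schedule[-1] * rollouts_per_epoch)      # 400,000
mpool_size = model_retain_epochs * model_steps_per_epoch                    # 8,000,000 model transitions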
Example #2
    def __init__(
            self,
            training_environment,
            evaluation_environment,
            policy,
            Qs,
            pool,
            static_fns,
            plotter=None,
            tf_summaries=False,

            lr=3e-4,
            reward_scale=1.0,
            target_entropy='auto',
            discount=0.99,
            tau=5e-3,
            target_update_interval=1,
            action_prior='uniform',
            reparameterize=False,
            store_extra_policy_info=False,

            deterministic=False,
            model_train_freq=250,
            model_train_slower=1,
            num_networks=7,
            num_elites=5,
            num_Q_elites=2,  # the number of Qs in the ensemble is set on the command line
            model_retain_epochs=20,
            rollout_batch_size=100e3,
            real_ratio=0.1,
            critic_same_as_actor=True,
            rollout_schedule=[20,100,1,1],
            hidden_dim=200,
            max_model_t=None,
            dir_name=None,
            evaluate_explore_freq=0,
            num_Q_per_grp=2,
            num_Q_grp=1,
            cross_grp_diff_batch=False,

            model_load_dir=None,
            model_load_index=None,
            model_log_freq=0,
            **kwargs,
    ):
        """
        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy: A policy function approximator.
            initial_exploration_policy ('Policy'): A policy used for initial
                exploration; it is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
            critic_same_as_actor ('bool'): If True, train the critic with the
                same sampling scheme (model-free or model-based) as the actor.
                Otherwise, train the critic with model-free samples only.
        """

        super(MBPO, self).__init__(**kwargs)

        if training_environment.unwrapped.spec.id.find("Fetch") != -1:
            # Fetch env
            obs_dim = sum([i.shape[0] for i in training_environment.observation_space.spaces.values()]) 
            self.multigoal = 1
        else:
            obs_dim = np.prod(training_environment.observation_space.shape)
        # print("====", obs_dim, "========")

        act_dim = np.prod(training_environment.action_space.shape)

        # TODO: add variable scope to directly extract model parameters
        self._model_load_dir = model_load_dir
        print("============Model dir: ", self._model_load_dir)
        if model_load_index:
            latest_model_index = model_load_index
        else:
            latest_model_index = self._get_latest_index()
        self._model = construct_model(obs_dim=obs_dim, act_dim=act_dim, hidden_dim=hidden_dim, num_networks=num_networks, num_elites=num_elites,
                                      model_dir=self._model_load_dir, model_load_timestep=latest_model_index, load_model=True if model_load_dir else False)
        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model, self._static_fns)

        model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self._model.name)
        all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        self._rollout_schedule = rollout_schedule
        self._max_model_t = max_model_t

        # self._model_pool_size = model_pool_size
        # print('[ MBPO ] Model pool size: {:.2E}'.format(self._model_pool_size))
        # self._model_pool = SimpleReplayPool(pool._observation_space, pool._action_space, self._model_pool_size)

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_ensemble = len(Qs)
        self._Q_elites = num_Q_elites
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto'
            else target_entropy)
        print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        # self._critic_train_repeat = kwargs["critic_train_repeat"]
        # the actor UTD ratio must be an integer multiple or divisor of the critic UTD ratio
        assert self._actor_train_repeat % self._critic_train_repeat == 0 or \
               self._critic_train_repeat % self._actor_train_repeat == 0

        self._critic_train_freq = self._n_train_repeat // self._critic_train_repeat
        self._actor_train_freq = self._n_train_repeat // self._actor_train_repeat
        self._critic_same_as_actor = critic_same_as_actor
        self._model_train_slower = model_train_slower
        self._origin_model_train_epochs = 0

        self._dir_name = dir_name
        self._evaluate_explore_freq = evaluate_explore_freq

        # Qs within a group are trained on the same data; Qs in different groups use different batches.
        self._num_Q_per_grp = num_Q_per_grp
        self._num_Q_grp = num_Q_grp
        self._cross_grp_diff_batch = cross_grp_diff_batch

        self._model_log_freq = model_log_freq

        self._build()
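
The assertion above only checks that one update-to-data (UTD) ratio divides the other; the derived frequencies then decide how often the critic and the actor are updated inside the per-step training loop. A standalone illustration with made-up values for _n_train_repeat, _actor_train_repeat and _critic_train_repeat (in the class these come from the base algorithm and command-line kwargs, not from this __init__):

# Illustration of the actor/critic update-frequency split (hypothetical values).
n_train_repeat = 20       # total gradient steps per environment step
critic_train_repeat = 20  # critic UTD ratio
actor_train_repeat = 10   # actor UTD ratio

assert (actor_train_repeat % critic_train_repeat == 0
        or critic_train_repeat % actor_train_repeat == 0)

critic_train_freq = n_train_repeat // critic_train_repeat   # 1 -> critic updated every repeat
actor_train_freq = n_train_repeat // actor_train_repeat     # 2 -> actor updated every 2nd repeat

for i in range(n_train_repeat):
    if i % critic_train_freq == 0:
        pass  # critic update would run here
    if i % actor_train_freq == 0:
        pass  # actor update would run here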
Example #3
    def __init__(self,
                 true_environment,
                 static_fns,
                 num_networks=7,
                 num_elites=5,
                 hidden_dims=(220, 220, 220),
                 dyn_discount=1,
                 cost_m_discount=1,
                 cares_about_cost=False,
                 session=None):

        self.domain = true_environment.domain if true_environment.domain else ""
        self.env = true_environment
        self.obs_dim = np.prod(self.observation_space.shape)
        self.act_dim = np.prod(self.action_space.shape)
        self.active_obs_dim = int(self.obs_dim / self.env.stacks)
        self._session = session
        self.cares_about_cost = cares_about_cost
        self.rew_dim = 1
        self.cost_classes = [0, 1]

        self.num_networks = num_networks
        self.num_elites = num_elites

        self.static_fns = static_fns

        target_weight_f = WEIGHTS_PER_DOMAIN.get(self.domain, None)
        self.target_weights = target_weight_f(
            self.obs_dim) if target_weight_f else None

        self._static_fns = static_fns  # termination functions for the envs (model can't simulate those)
        self.static_r = self.static_fns.reward_f if "reward_f" in dir(
            self.static_fns) else False
        self.static_done = self.static_fns.termination_fn if "termination_fn" in dir(
            self.static_fns) else False
        self.static_c = self.static_fns.cost_f if "cost_f" in dir(
            self.static_fns) else False
        self.post_f = self.static_fns.post_f if "post_f" in dir(
            self.static_fns) else False
        self.prior_f = self.static_fns.prior_f if "prior_f" in dir(
            self.static_fns) else False
        self.prior_dim = self.static_fns.PRIOR_DIM if "PRIOR_DIM" in dir(
            self.static_fns) else 0
        #### create fake env from model

        input_dim_dyn = self.obs_dim + self.prior_dim + self.act_dim
        input_dim_c = 2 * self.obs_dim + self.act_dim + self.prior_dim
        output_dim_dyn = self.active_obs_dim + self.rew_dim
        self.dyn_loss = 'MSPE'

        self._model = construct_model(in_dim=input_dim_dyn,
                                      out_dim=output_dim_dyn,
                                      name='BNN',
                                      loss=self.dyn_loss,
                                      hidden_dims=hidden_dims,
                                      lr=1e-3,
                                      num_networks=num_networks,
                                      num_elites=num_elites,
                                      weighted=dyn_discount < 1,
                                      use_scaler_in=True,
                                      use_scaler_out=True,
                                      decay=1e-6,
                                      max_logvar=.5,
                                      min_logvar=-10,
                                      session=self._session)
        if self.cares_about_cost and not self.static_c:

            self.cost_m_loss = 'MSE'
            output_activation = 'softmax' if self.cost_m_loss == 'CE' else None

            self._cost_model = construct_model(
                in_dim=input_dim_c,
                out_dim=2 if self.cost_m_loss == 'CE' else 1,
                loss=self.cost_m_loss,
                name='CostNN',
                hidden_dims=(64, 64),
                lr=1e-4,
                output_activation=output_activation,
                num_networks=num_networks,
                num_elites=num_elites,
                weighted=cost_m_discount < 1,
                use_scaler_in=False,
                use_scaler_out=False,
                decay=1e-6,
                session=self._session)

        else:
            self._cost_model = None

        self.dyn_target_var_rm = 1
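
The chain of '... if "x" in dir(self.static_fns) else False' lookups above resolves optional hooks (reward, termination, cost, post-processing, prior) on the static-function module. For reference, a sketch of the same lookup written with getattr, which behaves equivalently for these attribute-existence checks; the helper name is hypothetical and not part of the repository:

# Equivalent optional-hook lookup via getattr (hypothetical helper).
def resolve_static_hooks(static_fns):
    return {
        'static_r':    getattr(static_fns, 'reward_f', False),
        'static_done': getattr(static_fns, 'termination_fn', False),
        'static_c':    getattr(static_fns, 'cost_f', False),
        'post_f':      getattr(static_fns, 'post_f', False),
        'prior_f':     getattr(static_fns, 'prior_f', False),
        'prior_dim':   getattr(static_fns, 'PRIOR_DIM', 0),
    }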
Example #4
    def __init__(
        self,
        obs_space,
        act_space,
        session,
        logger=None,
        *args,
        **kwargs,
    ):

        # ___________________________________________ #
        #                   Params                    #
        # ___________________________________________ #
        self.hidden_sizes_a = kwargs.get('a_hidden_layer_sizes')
        self.hidden_sizes_c = kwargs.get('vf_hidden_layer_sizes')

        self.dyn_ensemble_size = kwargs.get('dyn_ensemble_size', 1)
        self.vf_lr = kwargs.get('vf_lr', 1e-4)

        self.vf_epochs = kwargs.get('vf_epochs', 10)
        self.vf_batch_size = kwargs.get('vf_batch_size', 64)
        self.vf_holdout = 0.1
        self.vf_train_kwargs = dict(batch_size=self.vf_batch_size,
                                    min_epoch_before_break=self.vf_epochs,
                                    max_epochs=self.vf_epochs,
                                    holdout_ratio=self.vf_holdout)

        self.vf_ensemble = kwargs.get('vf_ensemble_size', 5)
        self.vf_elites = kwargs.get('vf_elites', 3)
        self.v_logit_bias = (kwargs.get('v_logit_bias', 0))
        self.vc_logit_bias = kwargs.get('vc_logit_bias', 0)
        self.vf_activation = kwargs.get('vf_activation', 'ReLU')
        self.vf_loss = kwargs.get('vf_loss', 'MSE')
        self.vf_decay = kwargs.get('vf_decay', 1e-6)
        self.gaussian_vf = self.vf_loss == 'NLL'

        self.vf_cliploss = kwargs.get('vf_clipping', False)
        self.vf_cliprange = kwargs.get('vf_cliprange', 0.1)
        self.cvf_cliprange = kwargs.get('cvf_cliprange', 0.1)
        self.vf_var_corr = kwargs.get('vf_var_corr', False)

        self.ent_reg = kwargs.get('ent_reg', 0.0)
        self.cost_lim = kwargs.get('cost_lim', 25)
        self.real_c_buffer = [self.cost_lim] * 300

        self.target_kl = kwargs.get('target_kl', 0.01)
        self.cost_lam = kwargs.get('cost_lam', 0.97)
        self.cost_gamma = kwargs.get('cost_gamma', 0.99)
        self.lam = kwargs.get('lam', 0.97)
        self.gamma = kwargs.get('discount', 0.99)

        self.max_path_length = kwargs.get('max_path_length', 1)
        # usually not deterministic, but keep the option for deterministic eval runs
        self._deterministic = False

        cpo_kwargs = dict(
            reward_penalized=False,  # Irrelevant in CPO
            objective_penalized=False,  # Irrelevant in CPO
            learn_penalty=False,  # Irrelevant in CPO
            penalty_param_loss=False,  # Irrelevant in CPO
            learn_margin=True,  #learn_margin=True,
            c_gamma=self.cost_gamma,
            max_path_length=self.max_path_length)

        # ________________________________ #
        #       Cpo agent and logger       #
        # ________________________________ #

        log_dir = kwargs.get('log_dir', '~/ray_cmbpo/')
        self.agent = CPOAgent(**cpo_kwargs)
        exp_name = 'cpo'
        test_seed = random.randint(0, 9999)
        #logger_kwargs = setup_logger_kwargs(exp_name, test_seed, data_dir=log_dir)
        if logger:
            self.logger = logger
        else:
            self.logger = EpochLogger()
            #self.logger.save_config(locals())
            self.agent.set_logger(self.logger)

        self.saver = Saver()
        self.sess = session
        self.agent.prepare_session(self.sess)
        self.act_space = act_space
        self.obs_space = obs_space
        self.ep_len = kwargs.get('epoch_length')

        # ___________________________________________ #
        #              Prepare ac network             #
        # ___________________________________________ #
        scope = 'AC'
        with tf.variable_scope(scope):
            # tf placeholders
            with tf.variable_scope('obs_ph'):
                self.obs_ph = placeholders_from_spaces(self.obs_space)[0]
            with tf.variable_scope('a_ph'):
                self.a_ph = placeholders_from_spaces(self.act_space)[0]

            # input placeholders to computation graph for batch data
            with tf.variable_scope('adv_ph'):
                self.adv_ph = placeholder(None)
            with tf.variable_scope('cadv_ph'):
                self.cadv_ph = placeholder(None)

            with tf.variable_scope('logp_old_ph'):
                self.logp_old_ph = placeholder(None)
            with tf.variable_scope('surr_cost_rescale_ph'):
                # phs for cpo specific inputs to comp graph
                self.surr_cost_rescale_ph = placeholder(None)
            with tf.variable_scope('cur_cost_ph'):
                self.cur_cost_ph = placeholder(None)

            # critic phs
            with tf.variable_scope('ret_ph'):
                self.ret_ph = placeholder(None)
            with tf.variable_scope('cret_ph'):
                self.cret_ph = placeholder(None)
            with tf.variable_scope('old_v_ph'):
                self.old_v_ph = placeholder(None)
            with tf.variable_scope('old_vc_ph'):
                self.old_vc_ph = placeholder(None)
            with tf.variable_scope('old_v_var_ph'):
                self.old_v_var_ph = placeholder(None)
            with tf.variable_scope('old_vc_var_ph'):
                self.old_vc_var_ph = placeholder(None)

            #### _________________________________ ####
            ####            Create Actor           ####
            #### _________________________________ ####
            # kwargs for ac network
            a_kwargs = dict()
            a_kwargs['action_space'] = self.act_space
            a_kwargs['hidden_sizes'] = self.hidden_sizes_a
            a_kwargs['ensemble_size_3d'] = self.dyn_ensemble_size

            self.actor = mlp_actor

            actor_outs = self.actor(self.obs_ph, self.a_ph, **a_kwargs)
            if self.dyn_ensemble_size == 1:
                self.pi, self.logp, self.logp_pi, self.pi_info, self.pi_info_phs, self.d_kl, self.ent \
                    = actor_outs
            else:
                self.pi, self.logp, self.logp_pi, self.pi_info, self.pi_info_phs, self.d_kl, self.ent, \
                    self.pi_3d, self.logp_3d, self.logp_pi_3d, self.pi_info_3d = actor_outs

            #### _________________________________ ####
            ####       Create Critic (Ensemble)    ####
            #### _________________________________ ####

            vf_kwargs = dict()
            vf_kwargs['in_dim'] = np.prod(self.obs_space.shape)
            vf_kwargs['out_dim'] = 1
            vf_kwargs['hidden_dims'] = self.hidden_sizes_c
            vf_kwargs['lr'] = self.vf_lr
            vf_kwargs['num_networks'] = self.vf_ensemble
            vf_kwargs['activation'] = self.vf_activation
            vf_kwargs['loss'] = self.vf_loss
            vf_kwargs['decay'] = self.vf_decay
            vf_kwargs['clip_loss'] = self.vf_cliploss
            vf_kwargs['var_corr'] = self.vf_var_corr
            vf_kwargs['num_elites'] = self.vf_elites
            vf_kwargs['use_scaler_in'] = True
            vf_kwargs['use_scaler_out'] = True
            # vf_kwargs['sc_factor']      = 1e3
            vf_kwargs['session'] = self.sess

            self.v = construct_model(name='VEnsemble',
                                     max_logvar=-2,
                                     min_logvar=-10,
                                     logit_bias_std=self.v_logit_bias,
                                     **vf_kwargs)
            self.vc = construct_model(name='VCEnsemble',
                                      max_logvar=5,
                                      min_logvar=-10,
                                      logit_bias_std=self.vc_logit_bias,
                                      **vf_kwargs)

            # Organize placeholders for zipping with data from buffer on updates
            # careful ! this has to be in sync with the output of our buffer !
            self.buf_fields = [
                self.obs_ph, self.a_ph, self.adv_ph, 'ret_var', self.cadv_ph,
                'cret_var', self.ret_ph, self.cret_ph, self.logp_old_ph,
                self.old_v_ph, self.old_v_var_ph, self.old_vc_ph,
                self.old_vc_var_ph, self.cur_cost_ph
            ] + values_as_sorted_list(self.pi_info_phs)

            self.actor_phs = [
                self.obs_ph,
                self.a_ph,
                self.adv_ph,
                self.cadv_ph,
                self.logp_old_ph,
                self.cret_ph,
                self.cur_cost_ph,
            ] + values_as_sorted_list(self.pi_info_phs)

            self.critic_phs = [
                self.obs_ph,
                self.ret_ph,
                'ret_var',
                self.cret_ph,
                'cret_var',
                self.old_v_ph,
                self.old_v_var_ph,
                self.old_vc_ph,
                self.old_vc_var_ph,
            ]

            self.actor_fd = lambda x: {k: x[k] for k in self.actor_phs}
            self.critic_fd = lambda x: {k: x[k] for k in self.critic_phs}

            # organize tf ops required for generation of actions
            self.ops_for_action = dict(pi=self.pi,
                                       logp_pi=self.logp_pi,
                                       pi_info=self.pi_info)

            if self.dyn_ensemble_size > 1:
                self.ops_for_action_3d = dict(pi=self.pi_3d,
                                              logp_pi=self.logp_pi_3d,
                                              pi_info=self.pi_info_3d)

            # organize tf ops for diagnostics
            self.ops_for_diagnostics = dict(
                pi=self.pi,
                logp_pi=self.logp_pi,
                pi_info=self.pi_info,
            )

            # Count variables
            var_counts = tuple(
                count_vars(scope)
                for scope in ['pi', 'VEnsemble', 'VCEnsemble'])
            self.logger.log(
                '\nNumber of parameters: \t pi: %d, \t v: %d, \t vc: %d\n' %
                var_counts)

            # Make a sample estimate for entropy to use as sanity check
            #approx_ent = tf.reduce_mean(-self.logp)

            # ________________________________ #
            #    Computation graph for policy  #
            # ________________________________ #

            ratio = tf.exp(self.logp - self.logp_old_ph)

            # Surrogate advantage / clipped surrogate advantage
            self.surr_adv = tf.reduce_mean(
                ratio * self.adv_ph
            )  #* (1/self.adv_var_ph)) / (tf.reduce_sum(1/self.adv_var_ph))

            # Surrogate cost (advantage)
            self.surr_cost = tf.reduce_mean(
                ratio * self.cadv_ph
            )  # * 1/(self.cadv_var_ph)) / (tf.reduce_sum(1/self.cadv_var_ph))

            # Current Cret
            ### either cost-based or based on td returns
            # self.cur_cret_avg = tf.reduce_mean(self.cret_ph)*self.max_path_length*(1-self.c_gamma)
            self.cur_cret_avg = tf.reduce_mean(
                self.cur_cost_ph) * self.max_path_length

            # cost_lim if it were discounted
            # self.disc_cost_lim = (self.cost_lim/self.ep_len)

            # Create policy objective function, including entropy regularization
            pi_objective = self.surr_adv + self.ent_reg * self.ent

            # Loss function for pi is negative of pi_objective
            self.pi_loss = -pi_objective

            # Optimizer-specific symbols
            if self.agent.trust_region:  ### <------- CPO
                # Symbols needed for CG solver for any trust region method
                pi_params = get_vars('pi')
                flat_g = tro.flat_grad(self.pi_loss, pi_params)
                v_ph, hvp = tro.hessian_vector_product(self.d_kl, pi_params)
                if self.agent.damping_coeff > 0:
                    hvp += self.agent.damping_coeff * v_ph

                # Symbols needed for CG solver for CPO only
                flat_b = tro.flat_grad(self.surr_cost, pi_params)

                # Symbols for getting and setting params
                get_pi_params = tro.flat_concat(pi_params)
                set_pi_params = tro.assign_params_from_flat(v_ph, pi_params)

                self.training_package = dict(flat_g=flat_g,
                                             flat_b=flat_b,
                                             v_ph=v_ph,
                                             hvp=hvp,
                                             get_pi_params=get_pi_params,
                                             set_pi_params=set_pi_params)

            else:
                raise NotImplementedError

            # Provide training package to agent
            self.training_package.update(
                dict(pi_loss=self.pi_loss,
                     surr_cost=self.surr_cost,
                     cur_cret_avg=self.cur_cret_avg,
                     d_kl=self.d_kl,
                     target_kl=self.target_kl,
                     cost_lim=self.cost_lim,
                     real_cost_buf=self.real_c_buffer))
            self.agent.prepare_update(self.training_package)

        ##### set up saver after all graph building is done
        self.saver.init_saver(scope=scope)
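
The policy objective assembled above is the standard importance-sampling surrogate: ratio = exp(logp - logp_old) and surr_adv = mean(ratio * adv), with an analogous surrogate for the cost advantage. A minimal NumPy sketch of these quantities outside the TensorFlow graph; the batch values are made up for illustration:

import numpy as np

logp_old = np.array([-1.2, -0.8, -2.0])   # log-probs under the old policy
logp_new = np.array([-1.0, -0.9, -1.8])   # log-probs under the current policy
adv      = np.array([ 0.5, -0.2,  1.0])   # reward advantages
cadv     = np.array([ 0.1,  0.3, -0.4])   # cost advantages

ratio = np.exp(logp_new - logp_old)       # importance-sampling ratio
surr_adv  = np.mean(ratio * adv)          # surrogate (reward) advantage
surr_cost = np.mean(ratio * cadv)         # surrogate cost

ent_reg, ent = 0.0, 1.0                   # entropy regularization is off by default
pi_objective = surr_adv + ent_reg * ent
pi_loss = -pi_objective                   # the optimizer minimizes the negative objective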
Example #5
File: mbpo.py  Project: melfm/mbpo
    def __init__(
        self,
        training_environment,
        evaluation_environment,
        policy,
        Qs,
        pool,
        static_fns,
        plotter=None,
        tf_summaries=False,
        lr=3e-4,
        reward_scale=1.0,
        target_entropy='auto',
        discount=0.99,
        tau=5e-3,
        target_update_interval=1,
        action_prior='uniform',
        reparameterize=False,
        store_extra_policy_info=False,
        deterministic=False,
        model_train_freq=250,
        num_networks=7,
        num_elites=5,
        model_retain_epochs=20,
        rollout_batch_size=1e3,
        real_ratio=0.1,
        rollout_schedule=[20, 100, 1, 1],
        hidden_dim=200,
        max_model_t=None,
        shape_reward=False,
        max_action=1.0,
        **kwargs,
    ):
        """
        Args:
            env (`SoftlearningEnv`): Environment used for training.
            policy: A policy function approximator.
            initial_exploration_policy ('Policy'): A policy used for initial
                exploration; it is not trained by the algorithm.
            Qs: Q-function approximators. The min of these
                approximators will be used. Usage of at least two Q-functions
                improves performance by reducing overestimation bias.
            pool (`PoolBase`): Replay pool to add gathered samples to.
            plotter (`QFPolicyPlotter`): Plotter instance to be used for
                visualizing Q-function during training.
            lr (`float`): Learning rate used for the function approximators.
            discount (`float`): Discount factor for Q-function updates.
            tau (`float`): Soft value function target update weight.
            target_update_interval ('int'): Frequency at which target network
                updates occur in iterations.
            reparameterize ('bool'): If True, we use a gradient estimator for
                the policy derived using the reparameterization trick. We use
                a likelihood ratio based estimator otherwise.
        """
        super(MBPO, self).__init__(**kwargs)
        # for regular gym env
        #obs_dim = np.prod(training_environment.observation_space.shape)
        # for yuchen's modified env
        obs_dim = np.prod(
            training_environment.observation_space['observation'].shape)
        act_dim = np.prod(training_environment.action_space.shape)
        self.obs_dim_tup = training_environment.observation_space[
            'observation'].shape
        self.act_dim_tup = training_environment.action_space.shape
        self._model = construct_model(obs_dim=obs_dim,
                                      act_dim=act_dim,
                                      hidden_dim=hidden_dim,
                                      num_networks=num_networks,
                                      num_elites=num_elites)
        self._static_fns = static_fns
        self.fake_env = FakeEnv(self._model, self._static_fns)

        self._rollout_schedule = rollout_schedule
        self._max_model_t = max_model_t

        # self._model_pool_size = model_pool_size
        # print('[ MBPO ] Model pool size: {:.2E}'.format(self._model_pool_size))
        # self._model_pool = SimpleReplayPool(pool._observation_space, pool._action_space, self._model_pool_size)

        self._model_retain_epochs = model_retain_epochs

        self._model_train_freq = model_train_freq
        self._rollout_batch_size = int(rollout_batch_size)
        self._deterministic = deterministic
        self._real_ratio = real_ratio

        self._log_dir = os.getcwd()
        self._writer = Writer(self._log_dir)

        self._training_environment = training_environment
        self._evaluation_environment = evaluation_environment
        self._policy = policy

        self._Qs = Qs
        self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

        self._pool = pool
        # TODO: Fix hard-coded path
        # Only do this if we are shaping the reward
        print("Are we shaping the reward: {0}".format(
            shape_reward))  #TODO: remove this line once debugging is done
        if (shape_reward):
            demo_data = np.load("./mbpo/demonstration_data/demo_data_old.npz")
            # The demo data needs the next observations
            # TODO : Fix the skip last trajectory. The data should contain separate
            # observations and next_observations.
            samples = {
                'observations': demo_data["o"].reshape(-1, 6)[:-40],
                'actions': demo_data["u"].reshape(-1, 4),
                'next_observations': demo_data["o"].reshape(-1, 6)[:-40],
                'rewards': demo_data["r"].reshape(-1, 1),
                'terminals': demo_data["done"].reshape(-1, 1)
            }
            self._demo_pool = SimpleReplayPool(
                pool._observation_space['observation'], pool._action_space,
                pool._max_size)
            self._demo_pool.add_samples(samples)

        self._plotter = plotter
        self._tf_summaries = tf_summaries

        self._policy_lr = lr
        self._Q_lr = lr

        self._reward_scale = reward_scale
        self._target_entropy = (
            -np.prod(self._training_environment.action_space.shape)
            if target_entropy == 'auto' else target_entropy)
        print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

        self._discount = discount
        self._tau = tau
        self._target_update_interval = target_update_interval
        self._action_prior = action_prior

        self._reparameterize = reparameterize
        self._store_extra_policy_info = store_extra_policy_info

        observation_shape = self._training_environment.active_observation_shape
        action_shape = self._training_environment.action_space.shape

        assert len(observation_shape) == 1, observation_shape
        self._observation_shape = observation_shape
        assert len(action_shape) == 1, action_shape
        self._action_shape = action_shape

        self.shape_reward = shape_reward
        self.max_action = max_action

        self._build()
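
The demo-loading branch above currently reuses the same array slice for 'observations' and 'next_observations' (see the TODO in the code). Below is a hedged sketch of one way to build properly shifted next-observations from fixed-length demonstration trajectories; the function name, the traj_len argument, and the assumed shapes are illustrative, not taken from the repository:

import numpy as np

def build_transitions(obs, acts, rews, dones, traj_len):
    # Shift observations within each fixed-length trajectory so that
    # next_observations[t] corresponds to observations[t + 1].
    # Assumed shapes: obs (n_traj * traj_len, obs_dim); acts, rews, dones aligned per step.
    obs_dim, act_dim = obs.shape[-1], acts.shape[-1]
    obs   = obs.reshape(-1, traj_len, obs_dim)
    acts  = acts.reshape(-1, traj_len, act_dim)
    rews  = rews.reshape(-1, traj_len, 1)
    dones = dones.reshape(-1, traj_len, 1)
    return {
        'observations':      obs[:, :-1].reshape(-1, obs_dim),
        'next_observations': obs[:, 1:].reshape(-1, obs_dim),
        'actions':           acts[:, :-1].reshape(-1, act_dim),
        'rewards':           rews[:, :-1].reshape(-1, 1),
        'terminals':         dones[:, :-1].reshape(-1, 1),
    }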