def __init__(
        self,
        training_environment, evaluation_environment,
        policy, Qs, pool, static_fns,
        plotter=None, tf_summaries=False,
        lr=3e-4, reward_scale=1.0, target_entropy='auto',
        discount=0.99, tau=5e-3, target_update_interval=1,
        action_prior='uniform', reparameterize=False, store_extra_policy_info=False,
        deterministic=False, model_train_freq=250,
        num_networks=7, num_elites=5, model_retain_epochs=20,
        load_model_dir=None,
        rollout_batch_size=100e3, real_ratio=0.1,
        rollout_schedule=[20, 100, 1, 1],
        hidden_dim=200, max_model_t=None,
        **kwargs,
):
    """
    Args:
        env (`SoftlearningEnv`): Environment used for training.
        policy: A policy function approximator.
        initial_exploration_policy (`Policy`): A policy used for initial
            exploration; it is not trained by the algorithm.
        Qs: Q-function approximators. The minimum over these approximators
            is used. Using at least two Q-functions improves performance by
            reducing overestimation bias.
        pool (`PoolBase`): Replay pool to add gathered samples to.
        plotter (`QFPolicyPlotter`): Plotter instance used for visualizing
            the Q-function during training.
        lr (`float`): Learning rate used for the function approximators.
        discount (`float`): Discount factor for Q-function updates.
        tau (`float`): Soft value-function target update weight.
        target_update_interval (`int`): Frequency (in iterations) at which
            target network updates occur.
        reparameterize (`bool`): If True, use a gradient estimator for the
            policy derived via the reparameterization trick; otherwise use a
            likelihood-ratio-based estimator.
    """
    super(MBPO, self).__init__(**kwargs)

    obs_dim = np.prod(training_environment.observation_space.shape)
    act_dim = np.prod(training_environment.action_space.shape)

    self._model_params = dict(obs_dim=obs_dim, act_dim=act_dim, hidden_dim=hidden_dim,
                              num_networks=num_networks, num_elites=num_elites)
    if load_model_dir is not None:
        self._model_params['load_model'] = True
        self._model_params['model_dir'] = load_model_dir
    self._model = construct_model(**self._model_params)
    self._static_fns = static_fns
    self.fake_env = FakeEnv(self._model, self._static_fns)

    self._rollout_schedule = rollout_schedule
    self._max_model_t = max_model_t

    self._model_retain_epochs = model_retain_epochs
    self._model_train_freq = model_train_freq
    self._rollout_batch_size = int(rollout_batch_size)
    self._deterministic = deterministic
    self._real_ratio = real_ratio

    self._log_dir = os.getcwd()
    self._writer = Writer(self._log_dir)

    self._training_environment = training_environment
    self._evaluation_environment = evaluation_environment
    self._policy = policy

    self._Qs = Qs
    self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

    self._pool = pool
    self._plotter = plotter
    self._tf_summaries = tf_summaries

    self._policy_lr = lr
    self._Q_lr = lr

    self._reward_scale = reward_scale
    self._target_entropy = (
        -np.prod(self._training_environment.action_space.shape)
        if target_entropy == 'auto'
        else target_entropy)
    print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

    self._discount = discount
    self._tau = tau
    self._target_update_interval = target_update_interval
    self._action_prior = action_prior
    self._reparameterize = reparameterize
    self._store_extra_policy_info = store_extra_policy_info

    observation_shape = self._training_environment.active_observation_shape
    action_shape = self._training_environment.action_space.shape

    ### @anyboby: fixed pool size; reallocating the pool causes a memory leak
    obs_space = self._pool._observation_space
    act_space = self._pool._action_space
    rollouts_per_epoch = self._rollout_batch_size * self._epoch_length / self._model_train_freq
    model_steps_per_epoch = int(self._rollout_schedule[-1] * rollouts_per_epoch)
    mpool_size = self._model_retain_epochs * model_steps_per_epoch
    self._model_pool = SimpleReplayPool(obs_space, act_space, mpool_size)

    assert len(observation_shape) == 1, observation_shape
    self._observation_shape = observation_shape
    assert len(action_shape) == 1, action_shape
    self._action_shape = action_shape

    self._build()
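# --- Sketch (not part of the original source): how the fixed model-pool size
# above works out under this constructor's defaults. The epoch length is not
# set here; the value of 1000 environment steps is an assumption. ---
rollout_batch_size = int(100e3)   # model rollouts started per model training
epoch_length = 1000               # assumed; set elsewhere in the algorithm
model_train_freq = 250            # environment steps between model trainings
model_retain_epochs = 20
final_rollout_length = 1          # rollout_schedule[-1] with the default [20, 100, 1, 1]

rollouts_per_epoch = rollout_batch_size * epoch_length / model_train_freq  # 400,000
model_steps_per_epoch = int(final_rollout_length * rollouts_per_epoch)     # 400,000
mpool_size = model_retain_epochs * model_steps_per_epoch                   # 8,000,000 transitions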
def __init__(
        self,
        training_environment, evaluation_environment,
        policy, Qs, pool, static_fns,
        plotter=None, tf_summaries=False,
        lr=3e-4, reward_scale=1.0, target_entropy='auto',
        discount=0.99, tau=5e-3, target_update_interval=1,
        action_prior='uniform', reparameterize=False, store_extra_policy_info=False,
        deterministic=False, model_train_freq=250, model_train_slower=1,
        num_networks=7, num_elites=5,
        num_Q_elites=2,  # the size of the Q ensemble itself is set on the command line
        model_retain_epochs=20,
        rollout_batch_size=100e3, real_ratio=0.1, critic_same_as_actor=True,
        rollout_schedule=[20, 100, 1, 1],
        hidden_dim=200, max_model_t=None,
        dir_name=None, evaluate_explore_freq=0,
        num_Q_per_grp=2, num_Q_grp=1, cross_grp_diff_batch=False,
        model_load_dir=None, model_load_index=None, model_log_freq=0,
        **kwargs,
):
    """
    Args:
        env (`SoftlearningEnv`): Environment used for training.
        policy: A policy function approximator.
        initial_exploration_policy (`Policy`): A policy used for initial
            exploration; it is not trained by the algorithm.
        Qs: Q-function approximators. The minimum over these approximators
            is used. Using at least two Q-functions improves performance by
            reducing overestimation bias.
        pool (`PoolBase`): Replay pool to add gathered samples to.
        plotter (`QFPolicyPlotter`): Plotter instance used for visualizing
            the Q-function during training.
        lr (`float`): Learning rate used for the function approximators.
        discount (`float`): Discount factor for Q-function updates.
        tau (`float`): Soft value-function target update weight.
        target_update_interval (`int`): Frequency (in iterations) at which
            target network updates occur.
        reparameterize (`bool`): If True, use a gradient estimator for the
            policy derived via the reparameterization trick; otherwise use a
            likelihood-ratio-based estimator.
        critic_same_as_actor (`bool`): If True, use the same sampling scheme
            (model-free or model-based) for critic training as for the actor.
            Otherwise, train the critic on model-free samples only.
    """
    super(MBPO, self).__init__(**kwargs)

    if training_environment.unwrapped.spec.id.find("Fetch") != -1:
        # Fetch envs have a Dict observation space; flatten its components.
        obs_dim = sum([i.shape[0] for i in training_environment.observation_space.spaces.values()])
        self.multigoal = 1
    else:
        obs_dim = np.prod(training_environment.observation_space.shape)
    act_dim = np.prod(training_environment.action_space.shape)

    # TODO: add variable scope to directly extract model parameters
    self._model_load_dir = model_load_dir
    print("============Model dir: ", self._model_load_dir)
    if model_load_index:
        latest_model_index = model_load_index
    else:
        latest_model_index = self._get_latest_index()
    self._model = construct_model(obs_dim=obs_dim, act_dim=act_dim, hidden_dim=hidden_dim,
                                  num_networks=num_networks, num_elites=num_elites,
                                  model_dir=self._model_load_dir,
                                  model_load_timestep=latest_model_index,
                                  load_model=True if model_load_dir else False)
    self._static_fns = static_fns
    self.fake_env = FakeEnv(self._model, self._static_fns)

    model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self._model.name)
    all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

    self._rollout_schedule = rollout_schedule
    self._max_model_t = max_model_t

    self._model_retain_epochs = model_retain_epochs
    self._model_train_freq = model_train_freq
    self._rollout_batch_size = int(rollout_batch_size)
    self._deterministic = deterministic
    self._real_ratio = real_ratio

    self._log_dir = os.getcwd()
    self._writer = Writer(self._log_dir)

    self._training_environment = training_environment
    self._evaluation_environment = evaluation_environment
    self._policy = policy

    self._Qs = Qs
    self._Q_ensemble = len(Qs)
    self._Q_elites = num_Q_elites
    self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

    self._pool = pool
    self._plotter = plotter
    self._tf_summaries = tf_summaries

    self._policy_lr = lr
    self._Q_lr = lr

    self._reward_scale = reward_scale
    self._target_entropy = (
        -np.prod(self._training_environment.action_space.shape)
        if target_entropy == 'auto'
        else target_entropy)
    print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

    self._discount = discount
    self._tau = tau
    self._target_update_interval = target_update_interval
    self._action_prior = action_prior
    self._reparameterize = reparameterize
    self._store_extra_policy_info = store_extra_policy_info

    observation_shape = self._training_environment.active_observation_shape
    action_shape = self._training_environment.action_space.shape

    assert len(observation_shape) == 1, observation_shape
    self._observation_shape = observation_shape
    assert len(action_shape) == 1, action_shape
    self._action_shape = action_shape

    # self._critic_train_repeat = kwargs["critic_train_repeat"]
    # The actor UTD ratio must be an integer multiple (or divisor) of the critic UTD ratio.
    assert self._actor_train_repeat % self._critic_train_repeat == 0 or \
        self._critic_train_repeat % self._actor_train_repeat == 0
    self._critic_train_freq = self._n_train_repeat // self._critic_train_repeat
    self._actor_train_freq = self._n_train_repeat // self._actor_train_repeat
    self._critic_same_as_actor = critic_same_as_actor
    self._model_train_slower = model_train_slower
    self._origin_model_train_epochs = 0

    self._dir_name = dir_name
    self._evaluate_explore_freq = evaluate_explore_freq

    # Qs within a group are trained on the same batch;
    # Qs in different groups are trained on different batches.
    self._num_Q_per_grp = num_Q_per_grp
    self._num_Q_grp = num_Q_grp
    self._cross_grp_diff_batch = cross_grp_diff_batch

    self._model_log_freq = model_log_freq

    self._build()
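# --- Sketch (not part of the original source): the actor/critic update-frequency
# arithmetic behind the assert above. _n_train_repeat, _actor_train_repeat and
# _critic_train_repeat are not set in this constructor (they come from the
# superclass / command line); the values below are illustrative assumptions. ---
n_train_repeat = 20       # gradient-update slots per environment step
critic_train_repeat = 20  # critic updates per environment step (UTD ratio)
actor_train_repeat = 5    # actor updates per environment step

# one UTD ratio must be an integer multiple of the other
assert (actor_train_repeat % critic_train_repeat == 0
        or critic_train_repeat % actor_train_repeat == 0)

critic_train_freq = n_train_repeat // critic_train_repeat  # 1: train the critic in every slot
actor_train_freq = n_train_repeat // actor_train_repeat    # 4: train the actor every 4th slot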
def __init__(self,
             true_environment,
             static_fns,
             num_networks=7,
             num_elites=5,
             hidden_dims=(220, 220, 220),
             dyn_discount=1,
             cost_m_discount=1,
             cares_about_cost=False,
             session=None):
    self.domain = true_environment.domain if true_environment.domain else ""
    self.env = true_environment
    self.obs_dim = np.prod(self.observation_space.shape)
    self.act_dim = np.prod(self.action_space.shape)
    self.active_obs_dim = int(self.obs_dim / self.env.stacks)
    self._session = session
    self.cares_about_cost = cares_about_cost
    self.rew_dim = 1
    self.cost_classes = [0, 1]

    self.num_networks = num_networks
    self.num_elites = num_elites

    self.static_fns = static_fns
    target_weight_f = WEIGHTS_PER_DOMAIN.get(self.domain, None)
    self.target_weights = target_weight_f(self.obs_dim) if target_weight_f else None
    self._static_fns = static_fns

    # Static (hand-coded) reward/termination/cost functions for the envs;
    # the learned model cannot simulate these, so they are looked up by name.
    self.static_r = self.static_fns.reward_f if "reward_f" in dir(self.static_fns) else False
    self.static_done = self.static_fns.termination_fn if "termination_fn" in dir(self.static_fns) else False
    self.static_c = self.static_fns.cost_f if "cost_f" in dir(self.static_fns) else False
    self.post_f = self.static_fns.post_f if "post_f" in dir(self.static_fns) else False
    self.prior_f = self.static_fns.prior_f if "prior_f" in dir(self.static_fns) else False
    self.prior_dim = self.static_fns.PRIOR_DIM if "PRIOR_DIM" in dir(self.static_fns) else 0

    #### create fake env from model
    input_dim_dyn = self.obs_dim + self.prior_dim + self.act_dim
    input_dim_c = 2 * self.obs_dim + self.act_dim + self.prior_dim
    output_dim_dyn = self.active_obs_dim + self.rew_dim

    self.dyn_loss = 'MSPE'
    self._model = construct_model(in_dim=input_dim_dyn, out_dim=output_dim_dyn,
                                  name='BNN', loss=self.dyn_loss,
                                  hidden_dims=hidden_dims, lr=1e-3,
                                  num_networks=num_networks, num_elites=num_elites,
                                  weighted=dyn_discount < 1,
                                  use_scaler_in=True, use_scaler_out=True,
                                  decay=1e-6, max_logvar=.5, min_logvar=-10,
                                  session=self._session)

    if self.cares_about_cost and not self.static_c:
        self.cost_m_loss = 'MSE'
        output_activation = 'softmax' if self.cost_m_loss == 'CE' else None
        self._cost_model = construct_model(in_dim=input_dim_c,
                                           out_dim=2 if self.cost_m_loss == 'CE' else 1,
                                           loss=self.cost_m_loss, name='CostNN',
                                           hidden_dims=(64, 64), lr=1e-4,
                                           output_activation=output_activation,
                                           num_networks=num_networks, num_elites=num_elites,
                                           weighted=cost_m_discount < 1,
                                           use_scaler_in=False, use_scaler_out=False,
                                           decay=1e-6, session=self._session)
    else:
        self._cost_model = None

    self.dyn_target_var_rm = 1
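# --- Sketch (not part of the original source): the duck-typed static_fns interface
# probed with dir() in the constructor above. termination_fn follows the usual MBPO
# batched convention; the other hooks and their exact signatures are illustrative
# assumptions, shown only to clarify what the lookups expect. ---
import numpy as np


class ExampleStaticFns:
    PRIOR_DIM = 0  # size of hand-crafted prior features appended to the model input

    @staticmethod
    def termination_fn(obs, act, next_obs):
        # batched termination: (batch, obs_dim), (batch, act_dim), (batch, obs_dim) -> (batch, 1) bool
        return np.zeros((obs.shape[0], 1), dtype=bool)

    @staticmethod
    def cost_f(obs, act, next_obs):
        # hand-coded safety cost; if present, no cost model is learned
        return np.zeros((obs.shape[0], 1))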
def __init__(
        self,
        obs_space,
        act_space,
        session,
        logger=None,
        *args,
        **kwargs,
):
    # ___________________________________________ #
    #                    Params                   #
    # ___________________________________________ #

    self.hidden_sizes_a = kwargs.get('a_hidden_layer_sizes')
    self.hidden_sizes_c = kwargs.get('vf_hidden_layer_sizes')
    self.dyn_ensemble_size = kwargs.get('dyn_ensemble_size', 1)

    self.vf_lr = kwargs.get('vf_lr', 1e-4)
    self.vf_epochs = kwargs.get('vf_epochs', 10)
    self.vf_batch_size = kwargs.get('vf_batch_size', 64)
    self.vf_holdout = 0.1
    self.vf_train_kwargs = dict(batch_size=self.vf_batch_size,
                                min_epoch_before_break=self.vf_epochs,
                                max_epochs=self.vf_epochs,
                                holdout_ratio=self.vf_holdout)
    self.vf_ensemble = kwargs.get('vf_ensemble_size', 5)
    self.vf_elites = kwargs.get('vf_elites', 3)
    self.v_logit_bias = kwargs.get('v_logit_bias', 0)
    self.vc_logit_bias = kwargs.get('vc_logit_bias', 0)
    self.vf_activation = kwargs.get('vf_activation', 'ReLU')
    self.vf_loss = kwargs.get('vf_loss', 'MSE')
    self.vf_decay = kwargs.get('vf_decay', 1e-6)
    self.gaussian_vf = self.vf_loss == 'NLL'
    self.vf_cliploss = kwargs.get('vf_clipping', False)
    self.vf_cliprange = kwargs.get('vf_cliprange', 0.1)
    self.cvf_cliprange = kwargs.get('cvf_cliprange', 0.1)
    self.vf_var_corr = kwargs.get('vf_var_corr', False)

    self.ent_reg = kwargs.get('ent_reg', 0.0)
    self.cost_lim = kwargs.get('cost_lim', 25)
    self.real_c_buffer = [self.cost_lim] * 300
    self.target_kl = kwargs.get('target_kl', 0.01)
    self.cost_lam = kwargs.get('cost_lam', 0.97)
    self.cost_gamma = kwargs.get('cost_gamma', 0.99)
    self.lam = kwargs.get('lam', 0.97)
    self.gamma = kwargs.get('discount', 0.99)
    self.max_path_length = kwargs.get('max_path_length', 1)

    # Usually not deterministic, but keep the option for eval runs.
    self._deterministic = False

    cpo_kwargs = dict(
        reward_penalized=False,     # irrelevant in CPO
        objective_penalized=False,  # irrelevant in CPO
        learn_penalty=False,        # irrelevant in CPO
        penalty_param_loss=False,   # irrelevant in CPO
        learn_margin=True,
        c_gamma=self.cost_gamma,
        max_path_length=self.max_path_length)

    # ________________________________ #
    #       CPO agent and logger       #
    # ________________________________ #

    log_dir = kwargs.get('log_dir', '~/ray_cmbpo/')
    self.agent = CPOAgent(**cpo_kwargs)
    exp_name = 'cpo'
    test_seed = random.randint(0, 9999)

    if logger:
        self.logger = logger
    else:
        self.logger = EpochLogger()

    self.agent.set_logger(self.logger)
    self.saver = Saver()

    self.sess = session
    self.agent.prepare_session(self.sess)

    self.act_space = act_space
    self.obs_space = obs_space
    self.ep_len = kwargs.get('epoch_length')

    # ___________________________________________ #
    #             Prepare ac network              #
    # ___________________________________________ #

    scope = 'AC'
    with tf.variable_scope(scope):
        # tf placeholders
        with tf.variable_scope('obs_ph'):
            self.obs_ph = placeholders_from_spaces(self.obs_space)[0]
        with tf.variable_scope('a_ph'):
            self.a_ph = placeholders_from_spaces(self.act_space)[0]

        # input placeholders to computation graph for batch data
        with tf.variable_scope('adv_ph'):
            self.adv_ph = placeholder(None)
        with tf.variable_scope('cadv_ph'):
            self.cadv_ph = placeholder(None)
        with tf.variable_scope('logp_old_ph'):
            self.logp_old_ph = placeholder(None)

        # phs for CPO-specific inputs to the computation graph
        with tf.variable_scope('surr_cost_rescale_ph'):
            self.surr_cost_rescale_ph = placeholder(None)
        with tf.variable_scope('cur_cost_ph'):
            self.cur_cost_ph = placeholder(None)

        # critic phs
        with tf.variable_scope('ret_ph'):
            self.ret_ph = placeholder(None)
        with tf.variable_scope('cret_ph'):
            self.cret_ph = placeholder(None)
        with tf.variable_scope('old_v_ph'):
            self.old_v_ph = placeholder(None)
        with tf.variable_scope('old_vc_ph'):
            self.old_vc_ph = placeholder(None)
        with tf.variable_scope('old_v_var_ph'):
            self.old_v_var_ph = placeholder(None)
        with tf.variable_scope('old_vc_var_ph'):
            self.old_vc_var_ph = placeholder(None)

        #### _________________________________ ####
        ####            Create Actor           ####
        #### _________________________________ ####

        # kwargs for ac network
        a_kwargs = dict()
        a_kwargs['action_space'] = self.act_space
        a_kwargs['hidden_sizes'] = self.hidden_sizes_a
        a_kwargs['ensemble_size_3d'] = self.dyn_ensemble_size

        self.actor = mlp_actor
        actor_outs = self.actor(self.obs_ph, self.a_ph, **a_kwargs)

        if self.dyn_ensemble_size == 1:
            self.pi, self.logp, self.logp_pi, self.pi_info, self.pi_info_phs, \
                self.d_kl, self.ent = actor_outs
        else:
            self.pi, self.logp, self.logp_pi, self.pi_info, self.pi_info_phs, \
                self.d_kl, self.ent, \
                self.pi_3d, self.logp_3d, self.logp_pi_3d, self.pi_info_3d = actor_outs

        #### _________________________________ ####
        ####      Create Critic (Ensemble)     ####
        #### _________________________________ ####

        vf_kwargs = dict()
        vf_kwargs['in_dim'] = np.prod(self.obs_space.shape)
        vf_kwargs['out_dim'] = 1
        vf_kwargs['hidden_dims'] = self.hidden_sizes_c
        vf_kwargs['lr'] = self.vf_lr
        vf_kwargs['num_networks'] = self.vf_ensemble
        vf_kwargs['activation'] = self.vf_activation
        vf_kwargs['loss'] = self.vf_loss
        vf_kwargs['decay'] = self.vf_decay
        vf_kwargs['clip_loss'] = self.vf_cliploss
        vf_kwargs['var_corr'] = self.vf_var_corr
        vf_kwargs['num_elites'] = self.vf_elites
        vf_kwargs['use_scaler_in'] = True
        vf_kwargs['use_scaler_out'] = True
        vf_kwargs['session'] = self.sess

        self.v = construct_model(name='VEnsemble', max_logvar=-2, min_logvar=-10,
                                 logit_bias_std=self.v_logit_bias, **vf_kwargs)
        self.vc = construct_model(name='VCEnsemble', max_logvar=5, min_logvar=-10,
                                  logit_bias_std=self.vc_logit_bias, **vf_kwargs)

    # Organize placeholders for zipping with data from the buffer on updates.
    # Careful: this has to stay in sync with the output of our buffer!
    self.buf_fields = [
        self.obs_ph, self.a_ph,
        self.adv_ph, 'ret_var', self.cadv_ph, 'cret_var',
        self.ret_ph, self.cret_ph, self.logp_old_ph,
        self.old_v_ph, self.old_v_var_ph, self.old_vc_ph, self.old_vc_var_ph,
        self.cur_cost_ph,
    ] + values_as_sorted_list(self.pi_info_phs)

    self.actor_phs = [
        self.obs_ph, self.a_ph,
        self.adv_ph, self.cadv_ph, self.logp_old_ph,
        self.cret_ph, self.cur_cost_ph,
    ] + values_as_sorted_list(self.pi_info_phs)

    self.critic_phs = [
        self.obs_ph,
        self.ret_ph, 'ret_var', self.cret_ph, 'cret_var',
        self.old_v_ph, self.old_v_var_ph, self.old_vc_ph, self.old_vc_var_ph,
    ]

    self.actor_fd = lambda x: {k: x[k] for k in self.actor_phs}
    self.critic_fd = lambda x: {k: x[k] for k in self.critic_phs}

    # organize tf ops required for generation of actions
    self.ops_for_action = dict(pi=self.pi, logp_pi=self.logp_pi, pi_info=self.pi_info)
    if self.dyn_ensemble_size > 1:
        self.ops_for_action_3d = dict(pi=self.pi_3d, logp_pi=self.logp_pi_3d,
                                      pi_info=self.pi_info_3d)

    # organize tf ops for diagnostics
    self.ops_for_diagnostics = dict(pi=self.pi, logp_pi=self.logp_pi, pi_info=self.pi_info)

    # Count variables
    var_counts = tuple(count_vars(scope) for scope in ['pi', 'VEnsemble', 'VCEnsemble'])
    self.logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t vc: %d\n' % var_counts)

    # ________________________________ #
    #   Computation graph for policy   #
    # ________________________________ #

    # importance-sampling ratio between the current and the old policy
    ratio = tf.exp(self.logp - self.logp_old_ph)

    # Surrogate advantage
    self.surr_adv = tf.reduce_mean(ratio * self.adv_ph)

    # Surrogate cost (advantage)
    self.surr_cost = tf.reduce_mean(ratio * self.cadv_ph)

    # Current average cost return: either cost-based or based on TD returns.
    # TD-return-based alternative:
    #   self.cur_cret_avg = tf.reduce_mean(self.cret_ph) * self.max_path_length * (1 - self.c_gamma)
    self.cur_cret_avg = tf.reduce_mean(self.cur_cost_ph) * self.max_path_length

    # Create policy objective function, including entropy regularization
    pi_objective = self.surr_adv + self.ent_reg * self.ent

    # Loss function for pi is the negative of pi_objective
    self.pi_loss = -pi_objective

    # Optimizer-specific symbols
    if self.agent.trust_region:  # <------- CPO
        # Symbols needed for the CG solver of any trust-region method
        pi_params = get_vars('pi')
        flat_g = tro.flat_grad(self.pi_loss, pi_params)
        v_ph, hvp = tro.hessian_vector_product(self.d_kl, pi_params)
        if self.agent.damping_coeff > 0:
            hvp += self.agent.damping_coeff * v_ph

        # Symbols needed for the CG solver for CPO only
        flat_b = tro.flat_grad(self.surr_cost, pi_params)

        # Symbols for getting and setting params
        get_pi_params = tro.flat_concat(pi_params)
        set_pi_params = tro.assign_params_from_flat(v_ph, pi_params)

        self.training_package = dict(flat_g=flat_g, flat_b=flat_b,
                                     v_ph=v_ph, hvp=hvp,
                                     get_pi_params=get_pi_params,
                                     set_pi_params=set_pi_params)
    else:
        raise NotImplementedError

    # Provide training package to agent
    self.training_package.update(
        dict(pi_loss=self.pi_loss,
             surr_cost=self.surr_cost,
             cur_cret_avg=self.cur_cret_avg,
             d_kl=self.d_kl,
             target_kl=self.target_kl,
             cost_lim=self.cost_lim,
             real_cost_buf=self.real_c_buffer))
    self.agent.prepare_update(self.training_package)

    ##### set up saver after all graph building is done
    self.saver.init_saver(scope=scope)
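# --- Sketch (not part of the original source): intended use of buf_fields at
# update time. The buffer object and the exact shape of its output are assumptions;
# the point is that the buffer must emit its arrays in exactly the order of
# buf_fields, and actor_fd / critic_fd then slice the resulting dict into feed dicts. ---
def make_update_feeds(policy, buffer_values):
    # buffer_values: tuple/list of arrays, same order as policy.buf_fields
    batch = dict(zip(policy.buf_fields, buffer_values))
    actor_feed = policy.actor_fd(batch)    # keyed by policy.actor_phs
    critic_feed = policy.critic_fd(batch)  # keyed by policy.critic_phs (incl. 'ret_var'/'cret_var')
    return actor_feed, critic_feed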
def __init__(
        self,
        training_environment, evaluation_environment,
        policy, Qs, pool, static_fns,
        plotter=None, tf_summaries=False,
        lr=3e-4, reward_scale=1.0, target_entropy='auto',
        discount=0.99, tau=5e-3, target_update_interval=1,
        action_prior='uniform', reparameterize=False, store_extra_policy_info=False,
        deterministic=False, model_train_freq=250,
        num_networks=7, num_elites=5, model_retain_epochs=20,
        rollout_batch_size=1e3, real_ratio=0.1,
        rollout_schedule=[20, 100, 1, 1],
        hidden_dim=200, max_model_t=None,
        shape_reward=False, max_action=1.0,
        **kwargs,
):
    """
    Args:
        env (`SoftlearningEnv`): Environment used for training.
        policy: A policy function approximator.
        initial_exploration_policy (`Policy`): A policy used for initial
            exploration; it is not trained by the algorithm.
        Qs: Q-function approximators. The minimum over these approximators
            is used. Using at least two Q-functions improves performance by
            reducing overestimation bias.
        pool (`PoolBase`): Replay pool to add gathered samples to.
        plotter (`QFPolicyPlotter`): Plotter instance used for visualizing
            the Q-function during training.
        lr (`float`): Learning rate used for the function approximators.
        discount (`float`): Discount factor for Q-function updates.
        tau (`float`): Soft value-function target update weight.
        target_update_interval (`int`): Frequency (in iterations) at which
            target network updates occur.
        reparameterize (`bool`): If True, use a gradient estimator for the
            policy derived via the reparameterization trick; otherwise use a
            likelihood-ratio-based estimator.
    """
    super(MBPO, self).__init__(**kwargs)

    # For a regular gym env:
    #   obs_dim = np.prod(training_environment.observation_space.shape)
    # For Yuchen's modified (goal-based) env the observation space is a Dict:
    obs_dim = np.prod(training_environment.observation_space['observation'].shape)
    act_dim = np.prod(training_environment.action_space.shape)
    self.obs_dim_tup = training_environment.observation_space['observation'].shape
    self.act_dim_tup = training_environment.action_space.shape

    self._model = construct_model(obs_dim=obs_dim, act_dim=act_dim, hidden_dim=hidden_dim,
                                  num_networks=num_networks, num_elites=num_elites)
    self._static_fns = static_fns
    self.fake_env = FakeEnv(self._model, self._static_fns)

    self._rollout_schedule = rollout_schedule
    self._max_model_t = max_model_t

    self._model_retain_epochs = model_retain_epochs
    self._model_train_freq = model_train_freq
    self._rollout_batch_size = int(rollout_batch_size)
    self._deterministic = deterministic
    self._real_ratio = real_ratio

    self._log_dir = os.getcwd()
    self._writer = Writer(self._log_dir)

    self._training_environment = training_environment
    self._evaluation_environment = evaluation_environment
    self._policy = policy

    self._Qs = Qs
    self._Q_targets = tuple(tf.keras.models.clone_model(Q) for Q in Qs)

    self._pool = pool

    # TODO: Fix hard-coded path. Only load demonstrations when shaping the reward.
    print("Are we shaping the reward: {0}".format(shape_reward))  # TODO: remove once debugging is done
    if shape_reward:
        demo_data = np.load("./mbpo/demonstration_data/demo_data_old.npz")
        # The demo data needs the next observations.
        # TODO: fix the skip-last-trajectory hack; the data should contain
        # separate observations and next_observations.
        samples = {
            'observations': demo_data["o"].reshape(-1, 6)[:-40],
            'actions': demo_data["u"].reshape(-1, 4),
            'next_observations': demo_data["o"].reshape(-1, 6)[:-40],
            'rewards': demo_data["r"].reshape(-1, 1),
            'terminals': demo_data["done"].reshape(-1, 1)
        }
        self._demo_pool = SimpleReplayPool(pool._observation_space['observation'],
                                           pool._action_space,
                                           pool._max_size)
        self._demo_pool.add_samples(samples)

    self._plotter = plotter
    self._tf_summaries = tf_summaries

    self._policy_lr = lr
    self._Q_lr = lr

    self._reward_scale = reward_scale
    self._target_entropy = (
        -np.prod(self._training_environment.action_space.shape)
        if target_entropy == 'auto'
        else target_entropy)
    print('[ MBPO ] Target entropy: {}'.format(self._target_entropy))

    self._discount = discount
    self._tau = tau
    self._target_update_interval = target_update_interval
    self._action_prior = action_prior
    self._reparameterize = reparameterize
    self._store_extra_policy_info = store_extra_policy_info

    observation_shape = self._training_environment.active_observation_shape
    action_shape = self._training_environment.action_space.shape

    assert len(observation_shape) == 1, observation_shape
    self._observation_shape = observation_shape
    assert len(action_shape) == 1, action_shape
    self._action_shape = action_shape

    self.shape_reward = shape_reward
    self.max_action = max_action

    self._build()
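# --- Sketch (not part of the original source): the .npz layout the demonstration
# loader above appears to expect, read off the reshape calls (6-d observations,
# 4-d actions). The episode/step split and the meaning of each key are assumptions. ---
import numpy as np

n_episodes, horizon = 10, 40
demo = {
    "o": np.zeros((n_episodes, horizon, 6)),     # observations
    "u": np.zeros((n_episodes, horizon, 4)),     # actions
    "r": np.zeros((n_episodes, horizon, 1)),     # rewards
    "done": np.zeros((n_episodes, horizon, 1)),  # terminal flags
}
# The loader above reads from ./mbpo/demonstration_data/demo_data_old.npz
np.savez("demo_data_old.npz", **demo)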