コード例 #1
0
 def __init__(self, learner, oracle, policy, update_pol_nor=True,
              grad_std_n=10, grad_std_freq=None, log_sigmas_freq=None, log_sigmas_kwargs=None,
              gen_sim_ro=None, **kwargs):
     """Set up the learner, the type-checked oracle and policy, and counters.

     Args:
         learner: the online optimizer (``online_optimizer``) used for updates.
         oracle: checked with ``safe_assign`` against ``Or.rlOracle``.
         policy: checked with ``safe_assign`` against ``Policy``.
         update_pol_nor: stored as ``self._update_pol_nor``.
         gen_sim_ro: stored unchanged in the public attribute ``gen_sim_ro``.

     NOTE(review): ``grad_std_n``, ``grad_std_freq``, ``log_sigmas_freq``,
     ``log_sigmas_kwargs`` and ``**kwargs`` are accepted but not read in this
     body; confirm whether they are meant to be stored.
     """
     self._learner = learner
     # Components validated by safe_assign (oracle first, then policy).
     self._or = safe_assign(oracle, Or.rlOracle)
     self._policy = safe_assign(policy, Policy)
     # Plain bookkeeping state; the iteration counter starts at zero.
     self._itr, self._update_pol_nor, self.gen_sim_ro = 0, update_pol_nor, gen_sim_ro
コード例 #2
0
 def __init__(self, env, alg, gen_ro):
     """Bind a rollout generator to the algorithm's sampling functions.

     Args:
         env: stored as ``self._env``.
         alg: checked with ``safe_assign`` against ``Algorithm``.
         gen_ro: rollout-generating callable. It is kept unmodified in
             ``self._gen_ro_raw``, and ``self._gen_ro`` is a partial of it
             with ``pi`` and ``logp`` taken from the checked algorithm.
     """
     self._env = env
     self._alg = safe_assign(alg, Algorithm)
     self._gen_ro_raw = gen_ro
     # Keywords pre-bound on every rollout call.
     bound_kwargs = {'pi': self._alg.pi_ro, 'logp': self._alg.logp}
     self._gen_ro = functools.partial(gen_ro, **bound_kwargs)
     # Running count of data points seen so far.
     self._ndata = 0
コード例 #3
0
 def __init__(self, alg, env, ro_kwargs):
     """Type-check the algorithm and prepare a rollout generator for ``env``.

     Args:
         alg: checked with ``safe_assign`` against ``Algorithm``.
         env: bound as the ``env`` keyword of ``generate_rollout``.
         ro_kwargs: mapping of extra keywords forwarded to
             ``generate_rollout``; expected keys are 'min_n_samples',
             'max_n_rollouts' and 'max_rollout_len'.
     """
     self._alg = safe_assign(alg, Algorithm)
     rollout_fn = functools.partial(generate_rollout, env=env, **ro_kwargs)
     self._gen_ro = rollout_fn
     self._ndata = 0  # running count of data points seen so far
コード例 #4
0
    def __init__(
            self,
            pcl,
            oracle,
            policy,
            model_oracle=None,
            update_rule='model-free',
            update_in_pred=False,  # whether to update normalizer and ae in prediction
            take_first_pred=False,  # take a prediction step before the first correction step
            warm_start=True,  # whether to use first-order piccolo to warm start the VI problem
            shift_adv=False,  # make the adv to be positive
            stop_std_grad=False,  # freeze std
            ignore_samples=False,  # ignore all the information from samples
    ):
        """Store the optimizer, oracles and policy, and derive update-rule flags.

        Args:
            pcl: checked with ``safe_assign`` against ``Piccolo``/``PiccoloOpt``.
            oracle: checked with ``safe_assign`` against ``Or.rlOracle``.
            policy: checked with ``safe_assign`` against ``Policy``.
            model_oracle: optional; when given, checked against
                ``Or.rlOracle`` and stored as ``self._mor``. Otherwise
                ``self._mor`` is None.
            update_rule: one of 'piccolo', 'model-free', 'model-based',
                'dyna'; selects the ``_w_pred``/``_pre_w_adap``/``_w_corr``
                flags below.
            update_in_pred, take_first_pred, warm_start, shift_adv,
            stop_std_grad, ignore_samples: boolean options stored unchanged
                as the matching underscore-prefixed attributes.

        Raises:
            ValueError: if ``update_rule`` is not a supported rule. This is an
                explicit check, so it is not skipped under ``python -O``.
        """
        self._pcl = safe_assign(pcl, Piccolo, PiccoloOpt)
        self._or = safe_assign(oracle, Or.rlOracle)
        self._policy = safe_assign(policy, Policy)
        # Always define _mor so later attribute access cannot raise
        # AttributeError when no model oracle was supplied.
        if model_oracle is not None:
            self._mor = safe_assign(model_oracle, Or.rlOracle)
        else:
            self._mor = None

        self._itr = 0
        # Saved in the correction step to be used in the prediction step
        # (presumably _ro and _g — confirm).
        self._ro = None  # rollouts from the environment
        self._g = None

        # flags
        valid_rules = ('piccolo', 'model-free', 'model-based', 'dyna')
        if update_rule not in valid_rules:
            raise ValueError('update_rule must be one of {}, got {!r}'.format(
                valid_rules, update_rule))
        self._update_rule = update_rule
        # Which rules enable each phase; the names suggest "with prediction" /
        # "with correction" steps — NOTE(review): confirm against usage.
        self._w_pred = update_rule in ['piccolo', 'dyna', 'model-based']
        self._pre_w_adap = update_rule in ['dyna', 'model-based']
        self._w_corr = update_rule in ['model-free', 'piccolo', 'dyna']

        self._update_in_pred = update_in_pred
        self._take_first_pred = take_first_pred
        self._warm_start = warm_start
        self._shift_adv = shift_adv
        self._stop_std_grad = stop_std_grad
        self._ignore_samples = ignore_samples
コード例 #5
0
 def __init__(self, policy):
     """Keep a reference to ``policy`` after checking it with ``safe_assign``.

     Args:
         policy: validated against ``Policy`` and exposed as ``self.policy``.
     """
     checked_policy = safe_assign(policy, Policy)
     self.policy = checked_policy