コード例 #1
0
  def __init__(self, env_str, distinct=1, count=1, seeds=None):
    """Create a batch of seeded environments built from `env_str`.

    Args:
      env_str: Environment identifier forwarded to `get_env`.
      distinct: Number of random seeds drawn when `seeds` is not supplied.
      count: Number of environment copies created for every seed.
      seeds: Optional sequence of seeds. A falsy value (None or an empty
        sequence) means `distinct` random seeds are drawn instead.
    """
    self.distinct = distinct
    self.count = count
    self.total = self.distinct * self.count
    # The upper bound must be an int: random.randint rejects float
    # arguments such as 1e12 on Python 3.12+ (deprecated since 3.10).
    # NOTE(review): an explicit `seeds` list is presumably expected to have
    # `distinct` entries; otherwise len(self.envs) != self.total -- confirm
    # against callers.
    self.seeds = seeds or [random.randint(0, 10 ** 12)
                           for _ in range(self.distinct)]

    # `count` environments share each seed, laid out seed-major in the list.
    self.envs = []
    for seed in self.seeds:
      for _ in range(self.count):
        env = get_env(env_str)
        env.seed(seed)
        if hasattr(env, 'last'):
          env.last = 100  # for algorithmic envs
        self.envs.append(env)

    # Every slot starts out flagged as done.
    self.dones = [True] * self.total
    self.num_episodes_played = 0

    one_env = self.get_one()
    # Action spaces exposing a `spaces` attribute are handled as action lists.
    self.use_action_list = hasattr(one_env.action_space, 'spaces')
    # get_one() is deliberately called a second time, exactly as before, in
    # case it consumes random state.
    self.env_spec = env_spec.EnvSpec(self.get_one())
コード例 #2
0
ファイル: trainer.py プロジェクト: GD06/models
    def __init__(self):
        """Copy run configuration from FLAGS and build the environments.

        Reads every setting from the module-level `FLAGS` object, creates
        the training and evaluation `GymWrapper` instances, checks a few
        flag combinations with assertions, and finally snapshots all
        non-callable, non-dunder attributes into `self.hparams`.

        Raises:
          AssertionError: if an incompatible combination of flags is set.
        """
        self.batch_size = FLAGS.batch_size
        self.replay_batch_size = FLAGS.replay_batch_size
        if self.replay_batch_size is None:
            # Fall back to the online batch size when no replay size is given.
            self.replay_batch_size = self.batch_size
        self.num_samples = FLAGS.num_samples

        self.env_str = FLAGS.env
        # Kept as a local (not an attribute) so it does not leak into
        # self.hparams, which is built from dir(self) below.
        distinct = FLAGS.batch_size // self.num_samples
        self.env = gym_wrapper.GymWrapper(self.env_str,
                                          distinct=distinct,
                                          count=self.num_samples)
        self.eval_env = gym_wrapper.GymWrapper(self.env_str,
                                               distinct=distinct,
                                               count=self.num_samples)
        self.env_spec = env_spec.EnvSpec(self.env.get_one())

        self.max_step = FLAGS.max_step
        self.cutoff_agent = FLAGS.cutoff_agent
        self.num_steps = FLAGS.num_steps
        self.validation_frequency = FLAGS.validation_frequency

        self.target_network_lag = FLAGS.target_network_lag
        self.sample_from = FLAGS.sample_from
        assert self.sample_from in ['online', 'target']

        self.critic_weight = FLAGS.critic_weight
        self.objective = FLAGS.objective
        self.trust_region_p = FLAGS.trust_region_p
        self.value_opt = FLAGS.value_opt
        # trust_region_p is only accepted with the 'pcl' or 'trpo' objective,
        # and the 'trpo' objective requires trust_region_p.
        assert not self.trust_region_p or self.objective in ['pcl', 'trpo']
        assert self.objective != 'trpo' or self.trust_region_p
        # A value_opt other than None / 'None' requires critic_weight == 0.0.
        assert self.value_opt is None or self.value_opt == 'None' or \
            self.critic_weight == 0.0
        self.max_divergence = FLAGS.max_divergence

        self.learning_rate = FLAGS.learning_rate
        self.clip_norm = FLAGS.clip_norm
        self.clip_adv = FLAGS.clip_adv
        self.tau = FLAGS.tau
        self.tau_decay = FLAGS.tau_decay
        self.tau_start = FLAGS.tau_start
        self.eps_lambda = FLAGS.eps_lambda
        self.update_eps_lambda = FLAGS.update_eps_lambda
        self.gamma = FLAGS.gamma
        self.rollout = FLAGS.rollout
        self.use_target_values = FLAGS.use_target_values
        self.fixed_std = FLAGS.fixed_std
        self.input_prev_actions = FLAGS.input_prev_actions
        self.recurrent = FLAGS.recurrent
        # trust_region_p and recurrent may not both be set.
        assert not self.trust_region_p or not self.recurrent
        self.input_time_step = FLAGS.input_time_step
        # input_time_step requires cutoff_agent <= max_step.
        assert not self.input_time_step or (self.cutoff_agent <= self.max_step)

        self.use_online_batch = FLAGS.use_online_batch
        self.batch_by_steps = FLAGS.batch_by_steps
        self.unify_episodes = FLAGS.unify_episodes
        if self.unify_episodes:
            assert self.batch_size == 1

        self.replay_buffer_size = FLAGS.replay_buffer_size
        self.replay_buffer_alpha = FLAGS.replay_buffer_alpha
        self.replay_buffer_freq = FLAGS.replay_buffer_freq
        assert self.replay_buffer_freq in [-1, 0, 1]
        self.eviction = FLAGS.eviction
        self.prioritize_by = FLAGS.prioritize_by
        assert self.prioritize_by in ['rewards', 'step']
        self.num_expert_paths = FLAGS.num_expert_paths

        self.internal_dim = FLAGS.internal_dim
        self.value_hidden_layers = FLAGS.value_hidden_layers
        self.tf_seed = FLAGS.tf_seed

        self.save_trajectories_dir = FLAGS.save_trajectories_dir
        # Trajectory file is named after the env with '-' replaced by '_';
        # None when no save directory is configured.
        self.save_trajectories_file = (os.path.join(
            self.save_trajectories_dir, self.env_str.replace('-', '_'))
                                       if self.save_trajectories_dir else None)
        self.load_trajectories_file = FLAGS.load_trajectories_file

        # Snapshot every non-dunder, non-callable attribute. The builtin
        # callable() replaces isinstance(..., collections.Callable), an
        # alias that was removed from `collections` in Python 3.10.
        self.hparams = {
            attr: getattr(self, attr) for attr in dir(self)
            if not attr.startswith('__')
            and not callable(getattr(self, attr))}