コード例 #1
0
    def _build(self):
        """Build every experiment component described by ``self._variant``.

        A deep copy of ``self._variant`` is made up front, and that copy is
        what every factory below receives.

        The following are created in order, and each is stored on ``self``:
        the environment, the replay pool, the sampler, the Q-functions, the
        policy and a ``'UniformPolicy'`` initial exploration policy. All of
        them are then handed to ``get_algorithm_from_variant`` together with
        ``self._session``; the result is stored as ``self.algorithm``.

        Finally, ``initialize_tf_variables`` is called with
        ``only_uninitialized=True``. Presumably this leaves variables that
        are already initialized untouched; confirm against the helper. After
        that, ``self._built`` is set to ``True``.
        """
        variant = copy.deepcopy(self._variant)

        # Create each component, then mirror it onto ``self`` right away, so
        # the attribute assignments happen in the same order as creation.
        env = get_environment_from_variant(variant)
        self.env = env

        replay_pool = get_replay_pool_from_variant(variant, env)
        self.replay_pool = replay_pool

        sampler = get_sampler_from_variant(variant)
        self.sampler = sampler

        q_functions = get_Q_function_from_variant(variant, env)
        self.Qs = q_functions

        policy = get_policy_from_variant(variant, env, q_functions)
        self.policy = policy

        # The exploration policy is fixed to 'UniformPolicy' and does not
        # read the variant at all.
        exploration_policy = get_policy('UniformPolicy', env)
        self.initial_exploration_policy = exploration_policy

        algorithm_kwargs = dict(
            variant=variant,
            env=env,
            policy=policy,
            initial_exploration_policy=exploration_policy,
            Qs=q_functions,
            pool=replay_pool,
            sampler=sampler,
            session=self._session,
        )
        self.algorithm = get_algorithm_from_variant(**algorithm_kwargs)

        initialize_tf_variables(self._session, only_uninitialized=True)

        self._built = True
コード例 #2
0
ファイル: main.py プロジェクト: Haffon/synergyDRL
    def _build(self):
        """Build every experiment component described by ``self._variant``.

        A deep copy of ``self._variant`` is made up front. The environment,
        replay pool, sampler, Q-functions and policy factories all receive
        that copy.

        Each component is stored on ``self``, as is a ``'UniformPolicy'``
        initial exploration policy. All of them are then passed to
        ``get_algorithm_from_variant`` together with ``self._session``; the
        result is stored as ``self.algorithm``.

        Finally, ``initialize_tf_variables`` is called with
        ``only_uninitialized=True`` and ``self._built`` is set to ``True``.
        """
        variant = copy.deepcopy(self._variant)

        env = self.env = get_environment_from_variant(variant)
        replay_pool = self.replay_pool = (get_replay_pool_from_variant(
            variant, env))
        sampler = self.sampler = get_sampler_from_variant(variant)
        Qs = self.Qs = get_Q_function_from_variant(variant, env)
        # In this project the policy factory is called without the
        # Q-functions, unlike variants that pass ``Qs`` as a third argument.
        policy = self.policy = get_policy_from_variant(variant, env)

        # Fixed 'UniformPolicy'; it does not read the variant.
        initial_exploration_policy = self.initial_exploration_policy = (
            get_policy('UniformPolicy', env))

        # NOTE(review): the algorithm receives the original ``self._variant``
        # rather than the deep copy used above. It is kept unchanged to
        # preserve behavior; confirm this is intentional.
        self.algorithm = get_algorithm_from_variant(
            variant=self._variant,
            env=self.env,
            policy=policy,
            initial_exploration_policy=initial_exploration_policy,
            Qs=Qs,
            pool=replay_pool,
            sampler=sampler,
            session=self._session)

        # The former debug ``print`` that dumped every Placeholder op in the
        # default graph on each build was removed; it only produced noise.
        initialize_tf_variables(self._session, only_uninitialized=True)

        self._built = True