Example #1
    def __init__(self, model_info):
        model_config = model_info.get('model_config', None)
        import_config(globals(), model_config)

        self.state_dim = model_info['state_dim']
        self.action_dim = model_info['action_dim']
        super().__init__(model_info)
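
All of these examples funnel configuration through import_config(globals(), config), which promotes config entries to module-level names (so constants such as BUFFER_SIZE in Example #3 resolve). xingtian's actual helper is not shown on this page; a minimal sketch of the pattern, including the tolerance for None that Example #1 relies on, might look like this (the body below is an assumption, not xingtian's code):

def import_config(global_ns, config):
    # Sketch only: copy every config entry into the caller's globals.
    # xingtian's real helper may validate or filter keys.
    if not config:  # Example #1 can pass None when 'model_config' is absent
        return
    for key, value in config.items():
        global_ns[key] = value

# usage: afterwards BUFFER_SIZE is an ordinary module-level name
import_config(globals(), {"BUFFER_SIZE": 10000})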
Example #2
    def __init__(self, model_info):
        model_config = model_info.get("model_config", dict())
        import_config(globals(), model_config)
        self.dtype = DTYPE_MAP.get(model_config.get("dtype", "float32"))

        self.state_dim = model_info["state_dim"]
        self.action_dim = model_info["action_dim"]
        self.filter_arch = get_atari_filter(self.state_dim)

        # lr schedule with linear_cosine_decay
        self.lr_schedule = model_info.get("lr_schedule", None)
        self.opt_type = model_info.get("opt_type", "adam")
        self.lr = None

        self.ph_state = None
        self.ph_adv = None
        self.out_actions = None
        self.pi_logic_outs, self.baseline = None, None

        # placeholder for behavior policy logic outputs
        self.ph_bp_logic_outs = None
        self.ph_actions = None
        self.ph_dones = None
        self.ph_rewards = None
        self.loss, self.optimizer, self.train_op = None, None, None

        self.grad_norm_clip = 40.0
        self.sample_batch_steps = 50

        self.saver = None
        self.explore_paras = None
        self.actor_var = None  # store weights for agent

        super().__init__(model_info)
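
Example #2's comment points at linear_cosine_decay. If that follows TF1's tf.train.linear_cosine_decay (an assumption based on the comment), the schedule multiplies a linear ramp-down by a cosine term plus a small floor. A standalone NumPy sketch of that published formula:

import numpy as np

def linear_cosine_decay(lr, step, decay_steps,
                        num_periods=0.5, alpha=0.0, beta=0.001):
    # formula as documented for tf.train.linear_cosine_decay
    step = min(step, decay_steps)
    linear = (decay_steps - step) / decay_steps
    cosine = 0.5 * (1 + np.cos(np.pi * 2 * num_periods * step / decay_steps))
    return lr * ((alpha + linear) * cosine + beta)

print(linear_cosine_decay(1e-3, step=0, decay_steps=10000))      # ~1.001e-03
print(linear_cosine_decay(1e-3, step=10000, decay_steps=10000))  # 1e-06, i.e. lr * beta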
Example #3
File: dqn.py Project: zeta1999/xingtian
    def __init__(self, model_info, alg_config, **kwargs):
        """Initialize DQN algorithm. it's contains four steps:
        1. override the default config, with user's configuration;
        2. create the default actor with Algorithm.__init__;
        3. create once more actor, named by target_actor;
        4. create the replay buffer for training.
        :param model_info:
        :param alg_config:
        """
        import_config(globals(), alg_config)
        model_info = model_info["actor"]
        super(DQN, self).__init__(alg_name="dqn",
                                  model_info=model_info,
                                  alg_config=alg_config)

        self.target_actor = model_builder(model_info)
        self.buff = ReplayBuffer(BUFFER_SIZE)
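
Step 4 of the docstring builds the replay buffer that training later samples minibatches from. xingtian ships its own ReplayBuffer class; a minimal stand-in with an assumed add/sample interface (the method names are illustrative, not xingtian's API) fits in a few lines:

import random
from collections import deque

class ReplayBuffer:
    """Minimal FIFO replay buffer sketch; interface assumed."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)  # old transitions fall off the left

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # uniform sampling without replacement within one batch
        return random.sample(list(self.buffer), min(batch_size, len(self.buffer)))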
Example #4
    def __init__(self, model_info, alg_config, **kwargs):
        import_config(globals(), alg_config)
        super().__init__(alg_name="impala",
                         model_info=model_info["actor"],
                         alg_config=alg_config)

        self.dummy_action, self.dummy_value = (
            np.zeros((1, self.action_dim)),
            np.zeros((1, 1)),
        )

        self.async_flag = False  # fixme: refactor async_flag
        self.episode_len = alg_config.get("episode_len", 128)

        self.dist_model_policy = FIFODistPolicy(
            alg_config["instance_num"],
            prepare_times=self._prepare_times_per_train)

        self._init_train_list()
Example #5
File: ppo.py Project: zeta1999/xingtian
    def __init__(self, model_info, alg_config, **kwargs):
        """
        Algorithm instance; creates its model within `__init__`.
        :param model_info:
        :param alg_config:
        :param kwargs:
        """
        import_config(globals(), alg_config)
        super(PPO, self).__init__(
            alg_name=kwargs.get("name") or "ppo",
            model_info=model_info["actor"],
            alg_config=alg_config,
        )

        self._init_train_list()
        self.async_flag = False  # fixme: refactor async_flag

        if model_info.get("finetune_weight"):
            self.actor.load_model(model_info["finetune_weight"], by_name=True)
            print("load finetune weight: ", model_info["finetune_weight"])
Example #6
    def __init__(self, model_info, alg_config, **kwargs):
        import_config(globals(), alg_config)
        super().__init__(alg_name="impala",
                         model_info=model_info["actor"],
                         alg_config=alg_config)
        self.states = list()
        self.behavior_logits = list()
        self.actions = list()
        self.dones = list()
        self.rewards = list()
        self.async_flag = False

        # update to divide model policy
        self.dist_model_policy = EqualDistPolicy(
            alg_config["instance_num"],
            prepare_times=self._prepare_times_per_train)

        self.use_train_thread = False
        if self.use_train_thread:
            self.send_train = UniComm("LocalMsg")
            train_thread = threading.Thread(target=self._train_thread)
            train_thread.daemon = True
            train_thread.start()
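
The use_train_thread branch above is hard-coded off, but the pattern it gates is a standard one: a daemon thread blocks on a local message channel and runs training off the main loop. A generic standard-library sketch, with queue.Queue standing in for UniComm("LocalMsg") (an assumption about its role):

import queue
import threading
import time

class TrainerSketch:
    # Generic daemon-train-thread pattern; not xingtian's UniComm code.

    def __init__(self):
        self.send_train = queue.Queue()  # stands in for UniComm("LocalMsg")
        train_thread = threading.Thread(target=self._train_thread, daemon=True)
        train_thread.start()  # daemon: the thread dies with the process

    def _train_thread(self):
        while True:
            batch = self.send_train.get()  # blocks until a batch arrives
            self.train_on(batch)

    def train_on(self, batch):
        print("training on", batch)

trainer = TrainerSketch()
trainer.send_train.put({"states": [], "rewards": []})
time.sleep(0.1)  # let the daemon thread drain the queue before exit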