コード例 #1
0
ファイル: coordinatior.py プロジェクト: ArcQ/tensorflow-kf-1
    def __init__(self):
        config_obj = get_config()
        create_save_dir()
        state_size = config_obj["state_size"]

        self.model = build_ac_model()
        model_input = tf.convert_to_tensor(np.random.random((1, state_size)),
                                           dtype=tf.float32)
        self.model(model_input)
        self.total_steps = 0
        self.game_adapter = GameAdapter.create(get_config()['name'])
コード例 #2
0
ファイル: model.py プロジェクト: ArcQ/tensorflow-kf-1
def build_ac_model():
    from tensorflow.keras.models import Model
    from tensorflow.keras.layers import Dense, Input

    action_size = get_config()['action_size']
    state_size = get_config()['state_size']

    input_layer = Input(shape=(state_size))

    actor_dense = Dense(128, activation='relu')(input_layer)
    actor_logits = Dense(action_size, name='actor_logits')(actor_dense)

    critic_dense = Dense(128, activation='relu')(input_layer)
    critic_values = Dense(1, name='critic_values')(critic_dense)

    return Model(inputs=[input_layer], outputs=[actor_logits, critic_values])
コード例 #3
0
ファイル: master_agent.py プロジェクト: ArcQ/tensorflow-kf-1
    def __init__(self):
        config_obj = get_config()
        create_save_dir()
        state_size = config_obj["state_size"]

        ac_model = build_ac_model()
        model_input = tf.convert_to_tensor(
            np.random.random((1, state_size)),
            dtype=tf.float32
        )
        ac_model(global_model, model_input)

        self.global_model = global_model
        self.global_episode = 0
        self.global_moving_average_reward = 0
        self.best_score = 0
コード例 #4
0
def train_worker(lock, result_queue, worker_idx):
    total_step = 1
    mem = Memory()
    name = "worker-{}-{}".format(get_config()['name'], worker_idx)
    local_model = get_ac_model()
    game_adapter = GameAdapter.create(name)

    done = False
    while not done:
        print('hi')

        current_state = game_adapter.reset()
        mem.clear()

        (ep_steps) = train_ep(local_model, game_adapter, current_state)
        total_step += ep_steps

        lock.acquire()
        result_queue.put(None)
        lock.release()
        if total_step > WORKER_CHECK_STEPS_INTERVAL:
            lock.acquire()
            done = worker_done(master_agent)
            lock.release()
コード例 #5
0
def worker_done(master_agent):
    return master_agent.global_episode > get_config()['max_eps']
コード例 #6
0
ファイル: coordinatior.py プロジェクト: ArcQ/tensorflow-kf-1
 def is_done(self):
     return self.total_steps > get_config()["max_eps"]
コード例 #7
0
ファイル: run.py プロジェクト: ArcQ/tensorflow-kf-1
def train_random():
    config_obj = get_config()
    random_agent = build_random_model(config_obj['max_eps'])
    run_random_model(random_agent)
コード例 #8
0
ファイル: fs.py プロジェクト: ArcQ/tensorflow-kf-1
def get_save_path(file_name):
    config_obj = get_config()
    return os.path.join(get_save_dir(),
                        file_name.format('{}-' + config_obj['name']))
コード例 #9
0
ファイル: fs.py プロジェクト: ArcQ/tensorflow-kf-1
def get_save_dir():
    config_obj = get_config()
    return 'chase_trainer/data/' + config_obj['name']