Example No. 1
    def __init__(self, env_name, n_frames):

        self.env_name = env_name
        self.frame_process_func = util.get_preprocess_func(env_name)
        self.n_frames = n_frames
        self.action_space = gym.make(self.env_name).action_space.n
        self.q_network = RecurrentDuelingQNetwork(self.action_space)
        self.define_network()
Example No. 2
    def __init__(self,
                 env_id,
                 unroll_steps=5,
                 td_steps=5,
                 n_frames=8,
                 V_min=-30,
                 V_max=30,
                 gamma=0.998,
                 target_update_period=1600):

        self.env_id = env_id

        self.unroll_steps = unroll_steps

        self.td_steps = td_steps

        self.n_frames = n_frames

        self.V_min, self.V_max = V_min, V_max

        #: categorical value/reward distribution: one bin per integer in [V_min, V_max]
        self.n_supports = V_max - V_min + 1

        self.supports = tf.range(V_min, V_max + 1, dtype=tf.float32)

        self.gamma = gamma

        self.target_update_period = target_update_period

        self.action_space = gym.make(env_id).action_space.n

        self.repr_network = RepresentationNetwork(
            action_space=self.action_space)

        self.pv_network = PVNetwork(action_space=self.action_space,
                                    V_min=V_min,
                                    V_max=V_max)

        self.target_repr_network = RepresentationNetwork(
            action_space=self.action_space)

        self.target_pv_network = PVNetwork(action_space=self.action_space,
                                           V_min=V_min,
                                           V_max=V_max)

        self.dynamics_network = DynamicsNetwork(action_space=self.action_space,
                                                V_min=V_min,
                                                V_max=V_max)

        self.preprocess_func = util.get_preprocess_func(self.env_id)

        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025)

        self.update_count = 0

        self.setup()
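
Example No. 2 treats values and rewards as categorical distributions over the integer supports from V_min to V_max, which is why n_supports = V_max - V_min + 1. Below is a minimal sketch of the usual two-hot projection of a scalar target onto those supports; the helper name and exact convention are assumptions, not taken from this code.

import tensorflow as tf

def scalar_to_support(values, V_min=-30, V_max=30):
    #: Project scalars onto a categorical ("two-hot") distribution over the
    #: integer supports V_min, V_min+1, ..., V_max.
    n_supports = V_max - V_min + 1
    values = tf.clip_by_value(tf.cast(values, tf.float32), V_min, V_max)
    pos = values - V_min                       #: position on the support axis
    lower = tf.math.floor(pos)
    w_upper = pos - lower                      #: weight spilling onto the next bin
    onehot_lower = tf.one_hot(tf.cast(lower, tf.int32), n_supports)
    onehot_upper = tf.one_hot(tf.cast(lower, tf.int32) + 1, n_supports)
    return (1.0 - w_upper)[..., None] * onehot_lower + w_upper[..., None] * onehot_upper

The scalar can be recovered from the resulting probabilities with tf.reduce_sum(probs * supports, axis=-1).
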
Example No. 3
    def __init__(self, env_name, target_update_period, n_frames, gamma, eta,
                 alpha, burnin_length, unroll_length):
        self.env_name = env_name
        self.n_frames = n_frames
        self.action_space = gym.make(self.env_name).action_space.n
        self.frame_process_func = util.get_preprocess_func(env_name)

        self.q_network = RecurrentDuelingQNetwork(self.action_space)
        self.target_q_network = RecurrentDuelingQNetwork(self.action_space)
        self.target_update_period = target_update_period
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025, epsilon=0.001)

        self.gamma = gamma
        self.eta = eta
        self.alpha = alpha

        self.burnin_len = burnin_length
        self.unroll_len = unroll_length

        self.num_updated = 0
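
The eta and alpha arguments follow the R2D2 convention of mixing the max and mean absolute TD error over a sequence and applying a priority exponent. A minimal sketch of that convention (illustrative only; the repository's own priority computation is not shown in this excerpt):

import numpy as np

def sequence_priority(td_errors, eta=0.9, alpha=0.9):
    #: Mixed max/mean absolute TD error over the sequence, raised to the priority exponent.
    abs_td = np.abs(np.asarray(td_errors, dtype=np.float32))
    return float((eta * abs_td.max() + (1.0 - eta) * abs_td.mean()) ** alpha)
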
Example No. 4
    def __init__(self, pid, env_id, n_frames,
                 num_mcts_simulations, unroll_steps,
                 gamma, V_max, V_min, td_steps,
                 dirichlet_alpha, initial_randomize=True):

        self.pid = pid

        self.env_id = env_id

        self.unroll_steps = unroll_steps

        self.num_mcts_simulations = num_mcts_simulations

        self.action_space = gym.make(env_id).action_space.n

        self.n_frames = n_frames

        self.gamma = gamma

        self.dirichlet_alpha = dirichlet_alpha

        self.V_min, self.V_max = V_min, V_max

        self.td_steps = td_steps

        self.preprocess_func = util.get_preprocess_func(self.env_id)

        self.repr_network = RepresentationNetwork(
            action_space=self.action_space)

        self.pv_network = PVNetwork(action_space=self.action_space,
                                    V_min=V_min, V_max=V_max)

        self.dynamics_network = DynamicsNetwork(action_space=self.action_space,
                                                V_min=V_min, V_max=V_max)

        self.initial_randomize = initial_randomize

        self.setup()

        self.reset_env()
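
gamma and td_steps parameterize the bootstrapped value target computed from each trajectory. A minimal sketch of a standard n-step target under that convention (names are illustrative, not the repository's exact implementation):

def n_step_target(rewards, bootstrap_value, gamma=0.997, td_steps=5):
    #: Discounted sum of the next td_steps rewards, plus a discounted bootstrap
    #: value when the trajectory is long enough to bootstrap.
    target = sum(gamma ** i * r for i, r in enumerate(rewards[:td_steps]))
    if len(rewards) >= td_steps:
        target += gamma ** td_steps * bootstrap_value
    return target
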
Example No. 5
    def __init__(self, pid, env_name, n_frames,
                 epsilon, gamma, eta, alpha,
                 nstep, burnin_length, unroll_length):

        self.pid = pid
        self.env_name = env_name
        self.action_space = gym.make(env_name).action_space.n
        self.frame_process_func = util.get_preprocess_func(self.env_name)
        self.n_frames = n_frames

        self.q_network = RecurrentDuelingQNetwork(self.action_space)
        self.epsilon = epsilon
        self.gamma = gamma

        self.eta = eta
        self.alpha = alpha  # priority exponent

        self.nstep = nstep
        self.burnin_len = burnin_length
        self.unroll_len = unroll_length

        self.define_network()
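
Each actor carries its own epsilon, so action selection is presumably epsilon-greedy over the recurrent Q-network's outputs. A minimal sketch of that rule (the function name and shapes are assumptions):

import numpy as np

def epsilon_greedy(q_values, epsilon, action_space):
    #: Random action with probability epsilon, otherwise the greedy action.
    if np.random.rand() < epsilon:
        return np.random.randint(action_space)
    return int(np.argmax(q_values))
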
Example No. 6
    def __init__(self,
                 env_id: str,
                 config: Config,
                 pid: int = None,
                 epsilon: float = 0.,
                 summary_writer: tf.summary.SummaryWriter = None):

        self.env_id = env_id

        self.config = config

        self.pid = pid

        self.epsilon = epsilon

        self.summary_writer = summary_writer

        self.action_space = gym.make(self.env_id).action_space.n

        self.preprocess_func = util.get_preprocess_func(env_name=self.env_id)

        self.buffer = EpisodeBuffer(seqlen=self.config.sequence_length)

        self.world_model = WorldModel(config)
        self.wm_optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.config.lr_world, epsilon=1e-4)

        self.policy = PolicyNetwork(action_space=self.action_space)
        self.policy_optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.config.lr_actor, epsilon=1e-5)

        self.value = ValueNetwork(action_space=self.action_space)
        self.target_value = ValueNetwork(action_space=self.action_space)
        self.value_optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.config.lr_critic, epsilon=1e-5)

        self.setup()
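
Example No. 6 keeps a target_value network alongside the online critic; how it is refreshed is not shown in this excerpt. A common choice is a periodic hard copy of the online weights, sketched here under the assumption that setup() has already built both networks:

def sync_target_value(value_network, target_value_network):
    #: Hard update: copy the online critic's weights into the target critic.
    target_value_network.set_weights(value_network.get_weights())
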
Example No. 7
        x = self.bn2(self.conv2(x), training=training)
        x = x + inputs  #: skip connection
        x = relu(x)

        return x
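
#: The lines above are only the tail of a residual block's call(); a self-contained
#: sketch of a block with the same conv-BN / skip / ReLU shape follows. The filter
#: count and kernel size are assumptions, not read from this file.
import tensorflow as tf

class ResidualBlockSketch(tf.keras.layers.Layer):

    def __init__(self, filters=64):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(filters, 3, padding="same", use_bias=False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filters, 3, padding="same", use_bias=False)
        self.bn2 = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=False):
        #: inputs are assumed to already have `filters` channels so the skip adds up.
        x = tf.nn.relu(self.bn1(self.conv1(inputs), training=training))
        x = self.bn2(self.conv2(x), training=training)
        x = x + inputs  #: skip connection
        return tf.nn.relu(x)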


if __name__ == '__main__':
    import time
    import gym

    import util

    n_frames = 8
    env_name = "BreakoutDeterministic-v4"
    f = util.get_preprocess_func(env_name)

    env = gym.make(env_name)
    action_space = env.action_space.n

    frame = f(env.reset())

    frame_history = [frame] * n_frames
    action_history = [0, 1, 2, 3, 0, 1, 2, 3]

    repr_function = RepresentationNetwork(action_space=action_space)
    dynamics_function = DynamicsNetwork(action_space=action_space, V_min=-30, V_max=30)
    pv_network = PVNetwork(action_space=action_space, V_min=-30, V_max=30)

    hidden_state, obs = repr_function.predict(frame_history, action_history)
    hidden_states = tf.repeat(hidden_state, repeats=4, axis=0)  #: tile the root state 4 times along the batch axis
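
    #: The excerpt ends here. A plausible continuation (a sketch only, mirroring how
    #: DynamicsNetwork.predict is called in Example No. 9, one state and one action at
    #: a time) would expand the root state for every action and time the lookaheads
    #: with the `time` import above.
    start = time.time()
    for action in range(action_space):
        next_state, reward_pred = dynamics_function.predict(hidden_state, action)
        print("action:", action, "predicted reward:", reward_pred.numpy()[0][0])
    print("elapsed:", time.time() - start, "sec")
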
Example No. 8
    def __init__(self, action_space):

        super(ValueNetwork, self).__init__()

        self.mlp = MLPHead(out_shape=1)

    @tf.function
    def call(self, feat):

        value = self.mlp(feat)

        return value


if __name__ == '__main__':
    import gym
    import numpy as np
    from PIL import Image
    import util

    envname = "BreakoutDeterministic-v4"
    env = gym.make(envname)
    preprocess_func = util.get_preprocess_func(envname)
    obs = preprocess_func(env.reset())
    print(obs.shape)
    obs = obs[np.newaxis, ...]
    print(obs.shape)

    encoder = Encoder()
    s = encoder(obs)
    print(s.shape)
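
MLPHead is referenced by the ValueNetwork above but not defined in this excerpt. A hypothetical minimal version (layer sizes and activations are assumptions) mapping a feature vector to out_shape outputs:

import tensorflow as tf

class MLPHead(tf.keras.Model):
    #: Hypothetical two-layer MLP producing `out_shape` outputs.
    def __init__(self, out_shape, hidden_units=400):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(hidden_units, activation="elu")
        self.dense2 = tf.keras.layers.Dense(hidden_units, activation="elu")
        self.out = tf.keras.layers.Dense(out_shape)

    def call(self, x):
        return self.out(self.dense2(self.dense1(x)))
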
Example No. 9
def visualize(env_id="BreakoutDeterministic-v4",
              n_frames=4,
              V_min=-30,
              V_max=30):

    env = gym.make(env_id)

    action_space = env.action_space.n

    preprocess_func = util.get_preprocess_func(env_id)

    frame = preprocess_func(env.reset())

    frame_history = collections.deque([frame] * n_frames, maxlen=n_frames)

    action_history = collections.deque([0] * n_frames, maxlen=n_frames)

    repr_network = RepresentationNetwork(action_space=action_space)
    repr_network.load_weights("checkpoints/repr_net")

    pv_network = PVNetwork(action_space=action_space, V_min=V_min, V_max=V_max)
    pv_network.load_weights("checkpoints/pv_net")

    dynamics_network = DynamicsNetwork(action_space=action_space,
                                       V_min=V_min,
                                       V_max=V_max)
    dynamics_network.load_weights("checkpoints/dynamics_net")

    mcts = AtariMCTS(action_space=action_space,
                     pv_network=pv_network,
                     dynamics_network=dynamics_network,
                     gamma=0.997,
                     dirichlet_alpha=None)

    episode_rewards, episode_steps = 0, 0

    done = False

    images = []

    while not done:

        hidden_state, obs = repr_network.predict(frame_history, action_history)

        mcts_policy, action, root_value = mcts.search(hidden_state, 20, T=0.1)

        next_hidden_state, reward_pred = dynamics_network.predict(
            hidden_state, action)
        reward_pred = reward_pred.numpy()[0][0]

        frame, reward, done, info = env.step(action)

        print()
        print("STEP:", episode_steps)
        print("Reward", reward)
        print("Action", action)

        #: raw Atari frame, shape = (210, 160, 3)
        img_frame = Image.fromarray(frame)

        img_desc = Image.new('RGB', (280, 210), color="black")
        fnt = ImageFont.truetype("arial.ttf", 18)
        fnt_sm = ImageFont.truetype("arial.ttf", 12)

        pl = 30
        pb = 30

        v = str(round(root_value, 2))
        p = str([round(prob, 2) for prob in mcts_policy])
        r = str(round(reward_pred, 2))

        draw = ImageDraw.Draw(img_desc)
        draw.text((pl, 20), f"V(s): {v}", font=fnt, fill="white")
        draw.text((pl, 20 + pb), f"R(s, a): {r}", font=fnt, fill="white")
        draw.text((pl, 20 + pb * 2), f"π(s): {p}", font=fnt, fill="white")

        draw.text((pl, 20 + pb * 3.5), f"Note:", font=fnt_sm, fill="white")
        draw.text((pl, 20 + pb * 4),
                  "{ 0: Noop, 1: FIRE, 2: Left, 3: Right }",
                  font=fnt_sm,
                  fill="white")

        img_bg = Image.new(
            'RGB', (img_frame.width + img_desc.width, img_frame.height))

        img_bg.paste(img_frame, (0, 0))
        img_bg.paste(img_desc, (img_frame.width, 0))

        images.append(img_bg)

        episode_rewards += reward

        episode_steps += 1

        frame_history.append(preprocess_func(frame))

        action_history.append(action)

    print()
    print("====" * 5)
    print("FINISH")
    print(episode_steps, episode_rewards)

    images[0].save('tmp/muzero.gif',
                   save_all=True,
                   append_images=images[1:],
                   optimize=False,
                   duration=60,
                   loop=0)

    return episode_rewards, episode_steps
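
A minimal way to run the function above (assuming the checkpoints/ weights and the tmp/ output directory already exist, as the code requires):

if __name__ == '__main__':
    rewards, steps = visualize()
    print("episode reward:", rewards, "steps:", steps)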