Beispiel #1
0
    def test_example(self):
        """EarlyStopping halts `fit` once the monitored loss stops improving.

        Trains a single-layer MLP for up to 10 epochs with
        ``EarlyStopping(monitor="loss", patience=3)`` and checks that the
        recorded history contains 7 epochs rather than the 10 requested.
        """

        class MLP(elegy.Module):
            def __apply__(self, input):
                mlp = hk.Sequential([
                    hk.Linear(10),
                ])
                return mlp(input)

        # This callback will stop the training when there is no improvement in
        # the monitored "loss" for three consecutive epochs.
        callback = elegy.callbacks.EarlyStopping(monitor="loss", patience=3)
        model = elegy.Model(
            module=MLP.defer(),
            loss=elegy.losses.MeanSquaredError(),
            optimizer=optix.rmsprop(0.01),
        )
        history = model.fit(
            np.arange(100).reshape(5, 20).astype(np.float32),
            np.zeros(5),
            epochs=10,
            batch_size=1,
            callbacks=[callback],
            verbose=0,
        )
        # Only 7 of the 10 requested epochs are run. Include the recorded
        # losses in the failure message so a mismatch is easy to diagnose.
        losses = history.history["loss"]
        assert len(losses) == 7, losses
def setup_learner():
    """Setup learner for distributed setting.

    A sample Catch environment is created only so that its action and
    observation specs can be used to construct the agent.

    Returns:
        The constructed ``learner_lib.Learner``.
    """
    # Substitute your own environment constructor here.
    sample_env = catch.Catch()
    action_count = sample_env.action_spec().num_values
    catch_agent = agent_lib.Agent(
        action_count,
        sample_env.observation_spec(),
        haiku_nets.CatchNet,
    )

    optimizer = optix.rmsprop(1e-1, decay=0.99, eps=0.1)

    return learner_lib.Learner(
        catch_agent,
        jax.random.PRNGKey(428),
        optimizer,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )
Beispiel #3
0
def main(_):
    """Train a Catch agent with one learner and ``NUM_ACTORS`` actor threads.

    Each actor gets its own environment instance and runs ``run_actor`` on a
    separate thread, sharing ``stop_signal`` with the main thread. The learner
    runs for ``MAX_ENV_FRAMES // FRAMES_PER_ITER`` updates. Afterwards the
    actors are signalled to stop and joined.

    Args:
        _: Unused positional argument.
    """
    # A thunk that builds a new environment.
    # Substitute your environment here!
    build_env = catch.Catch

    # Construct the agent. We need a sample environment for its spec.
    env_for_spec = build_env()
    num_actions = env_for_spec.action_spec().num_values
    agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(),
                            haiku_nets.CatchNet)

    # Total number of learner updates. Floor division keeps this exact for
    # integer inputs instead of round-tripping through a float.
    max_updates = MAX_ENV_FRAMES // FRAMES_PER_ITER

    # Construct the optimizer.
    opt = optix.rmsprop(1e-1, decay=0.99, eps=0.1)

    # Construct the learner.
    learner = learner_lib.Learner(
        agent,
        jax.random.PRNGKey(428),
        opt,
        BATCH_SIZE,
        DISCOUNT_FACTOR,
        FRAMES_PER_ITER,
        max_abs_reward=1.,
        logger=util.AbslLogger(),  # Provide your own logger here.
    )

    # Construct the actors on different threads.
    # stop_signal is a one-element list so every thread shares the reference.
    actor_threads = []
    stop_signal = [False]
    for i in range(NUM_ACTORS):
        actor = actor_lib.Actor(
            agent,
            build_env(),
            UNROLL_LENGTH,
            learner,
            rng_seed=i,
            logger=util.AbslLogger(),  # Provide your own logger here.
        )
        actor_threads.append(
            threading.Thread(target=run_actor, args=(actor, stop_signal)))

    # Start the actors and learner. The try/finally always signals and joins
    # the actors, even if starting a thread or learner.run raises. Without it,
    # the (non-daemon) actor threads could keep the process alive. Only
    # threads that actually started are joined, because joining an unstarted
    # thread raises RuntimeError.
    started = []
    try:
        for t in actor_threads:
            t.start()
            started.append(t)
        learner.run(int(max_updates))
    finally:
        # Stop.
        stop_signal[0] = True
        for t in started:
            t.join()
Beispiel #4
0
def main(_):
  """Train a Catch agent with actor threads feeding one learner via a queue.

  Actors push trajectories with ``q.put`` onto a bounded queue. The learner
  pops ``batch_size`` trajectories at a time, stacks them along axis 1 and
  applies one update per batch. Finally it waits for the actor threads to
  finish.

  Args:
    _: Unused positional argument.
  """

  # Construct the agent network. We need a sample environment for its spec.
  env = catch.Catch()
  num_actions = env.action_spec().num_values
  net = hk.transform(lambda ts: SimpleNet(num_actions)(ts))  # pylint: disable=unnecessary-lambda

  # Construct the agent and learner.
  agent = Agent(net.apply)
  opt = optix.rmsprop(1e-1, decay=0.99, eps=0.1)
  learner = Learner(agent, opt.update)

  # Initialize the parameters (from a sample step with a leading batch axis of
  # size 1) and the optimizer state.
  sample_ts = env.reset()
  sample_ts = preprocess_step(sample_ts)
  ts_with_batch = jax.tree_map(lambda t: np.expand_dims(t, 0), sample_ts)
  params = jax.jit(net.init)(jax.random.PRNGKey(428), ts_with_batch)
  opt_state = opt.init(params)

  # Create accessor and queueing functions.
  # current_params closes over the *variable* `params`. Actors therefore see
  # whatever value the learner loop below has most recently rebound it to.
  current_params = lambda: params
  batch_size = 2
  q = queue.Queue(maxsize=batch_size)

  def dequeue():
    """Block until `batch_size` trajectories arrive; stack them on axis 1."""
    batch = [q.get() for _ in range(batch_size)]
    batch = jax.tree_multimap(lambda *ts: np.stack(ts, axis=1), *batch)
    return jax.device_put(batch)

  # Start the actors, keeping their handles so they can be joined later.
  num_actors = 2
  trajectories_per_actor = 500
  unroll_len = 20
  actor_threads = []
  for i in range(num_actors):
    key = jax.random.PRNGKey(i)
    args = (agent, key, current_params, q.put, unroll_len,
            trajectories_per_actor)
    thread = threading.Thread(target=run_actor, args=args)
    thread.start()
    actor_threads.append(thread)

  # Run the learner. NOTE(review): this presumably relies on each actor
  # putting exactly `trajectories_per_actor` items, and on the total being
  # divisible by batch_size. Otherwise an actor can block on the bounded
  # queue's put.
  num_steps = num_actors * trajectories_per_actor // batch_size
  for _ in range(num_steps):
    traj = dequeue()
    params, opt_state = learner.update(params, opt_state, traj)

  # Wait for every actor thread to finish before returning.
  for thread in actor_threads:
    thread.join()
Beispiel #5
0
    def test_rmsprop(self):
        """optix.rmsprop must reproduce experimental/optimizers.py rmsprop.

        Both implementations start from ``self.init_params`` and apply
        ``self.per_step_updates`` for ``STEPS`` steps, using the same learning
        rate, decay and epsilon. The final parameter leaves must agree to a
        relative tolerance of 1e-5.
        """
        decay = .9
        eps = 0.1

        # Reference trajectory: experimental/optimizers.py.
        reference_params = self.init_params
        init_fn, update_fn, params_fn = optimizers.rmsprop(LR, decay, eps)
        reference_state = init_fn(reference_params)
        for step in range(STEPS):
            reference_state = update_fn(
                step, self.per_step_updates, reference_state)
            reference_params = params_fn(reference_state)

        # Candidate trajectory: experimental/optix.py.
        candidate_params = self.init_params
        transform = optix.rmsprop(LR, decay, eps)
        optix_state = transform.init(candidate_params)
        for _ in range(STEPS):
            deltas, optix_state = transform.update(
                self.per_step_updates, optix_state)
            candidate_params = optix.apply_updates(candidate_params, deltas)

        # Compare the two results leaf by leaf.
        leaf_pairs = zip(tree_leaves(reference_params),
                         tree_leaves(candidate_params))
        for expected, actual in leaf_pairs:
            np.testing.assert_allclose(expected, actual, rtol=1e-5)
Beispiel #6
0
def main(debug: bool = False, eager: bool = False):
    """Train an MLP classifier on MNIST, plot its history and sample predictions.

    Args:
        debug: If True, start a debugpy server on port 5678 and block until a
            debugger attaches.
        eager: Passed to ``elegy.Model`` as ``run_eagerly``.
    """

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    X_train, y_train, X_test, y_test = dataget.image.mnist(
        global_cache=True).get()

    print("X_train:", X_train.shape, X_train.dtype)
    print("y_train:", y_train.shape, y_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)
    print("y_test:", y_test.shape, y_test.dtype)

    class MLP(elegy.Module):
        """Standard LeNet-300-100 MLP network."""
        def __init__(self, n1: int = 300, n2: int = 100, **kwargs):
            super().__init__(**kwargs)
            self.n1 = n1
            self.n2 = n2

        def call(self, image: jnp.ndarray):

            # Scale to [0, 1]; assumes raw pixel values in [0, 255].
            image = image.astype(jnp.float32) / 255.0

            mlp = hk.Sequential([
                hk.Flatten(),
                hk.Linear(self.n1),
                jax.nn.relu,
                hk.Linear(self.n2),
                jax.nn.relu,
                hk.Linear(10),
            ])
            return dict(outputs=mlp(image))

    model = elegy.Model(
        module=MLP.defer(n1=300, n2=100),
        loss=[
            elegy.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                       on="outputs"),
            elegy.regularizers.GlobalL2(l=1e-4),
        ],
        metrics=elegy.metrics.SparseCategoricalAccuracy.defer(on="outputs"),
        optimizer=optix.rmsprop(1e-3),
        run_eagerly=eager,
    )

    history = model.fit(
        x=X_train,
        y=dict(outputs=y_train),
        epochs=100,
        steps_per_epoch=200,
        batch_size=64,
        validation_data=(X_test, dict(outputs=y_test)),
        shuffle=True,
    )

    plot_history(history)

    # Get random samples. The index range comes from the actual test-set
    # size rather than a hard-coded 10000.
    idxs = np.random.randint(0, X_test.shape[0], size=(9, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # Plot results on a 3x3 grid, titled with the arg-max predicted class.
    plt.figure(figsize=(12, 12))
    for i in range(3):
        for j in range(3):
            k = 3 * i + j
            plt.subplot(3, 3, k + 1)

            plt.title(f"{np.argmax(y_pred['outputs'][k])}")
            plt.imshow(x_sample[k], cmap="gray")

    plt.show()
Beispiel #7
0
    def __init__(
        self,
        num_agent_steps,
        state_space,
        action_space,
        seed,
        max_grad_norm=None,
        gamma=0.99,
        nstep=1,
        buffer_size=10 ** 6,
        use_per=False,
        batch_size=32,
        start_steps=50000,
        update_interval=4,
        update_interval_target=8000,
        eps=0.01,
        eps_eval=0.001,
        eps_decay_steps=250000,
        loss_type="huber",
        dueling_net=False,
        double_q=False,
        setup_net=True,
        fn=None,
        lr=5e-5,
        lr_cum_p=2.5e-9,
        units=(512,),
        num_quantiles=32,
        num_cosines=64,
    ):
        """Construct the FQF agent.

        The shared hyper-parameters are forwarded to the parent class. The
        parent is always called with ``setup_net=False``, so any networks are
        built here instead.

        FQF-specific args:
            setup_net: If True, build the quantile network (from ``fn``) with
                an Adam optimizer, and initialize the fraction-proposal
                network's parameters and RMSProp state. If False, only
                ``self.cum_p_net`` and ``self.opt_cum_p`` are created here.
                ``self.net``, ``self.params``, ``self.params_cum_p`` and the
                optimizer states must then be set up elsewhere.
            fn: Optional builder called as ``fn(s, cum_p)``. Defaults to a
                ``DiscreteImplicitQuantileFunction``.
            lr: Adam learning rate for the quantile network.
            lr_cum_p: Centered-RMSProp learning rate for the fraction-proposal
                network.
            units: Forwarded as ``hidden_units`` to the default ``fn``.
            num_quantiles: Forwarded to the parent class, to
                ``make_quantile_nerwork`` and to ``CumProbNetwork``.
            num_cosines: Forwarded to the default ``fn``.
        """
        super(FQF, self).__init__(
            num_agent_steps=num_agent_steps,
            state_space=state_space,
            action_space=action_space,
            seed=seed,
            max_grad_norm=max_grad_norm,
            gamma=gamma,
            nstep=nstep,
            buffer_size=buffer_size,
            batch_size=batch_size,
            use_per=use_per,
            start_steps=start_steps,
            update_interval=update_interval,
            update_interval_target=update_interval_target,
            eps=eps,
            eps_eval=eps_eval,
            eps_decay_steps=eps_decay_steps,
            loss_type=loss_type,
            dueling_net=dueling_net,
            double_q=double_q,
            setup_net=False,
            num_quantiles=num_quantiles,
        )
        if setup_net:
            if fn is None:

                def fn(s, cum_p):
                    return DiscreteImplicitQuantileFunction(
                        num_cosines=num_cosines,
                        action_space=action_space,
                        hidden_units=units,
                        dueling_net=dueling_net,
                    )(s, cum_p)

            self.net, self.params, fake_feature = make_quantile_nerwork(self.rng, state_space, action_space, fn, num_quantiles)
            self.params_target = self.params
            # Adam epsilon is scaled inversely with the batch size.
            opt_init, self.opt = optix.adam(lr, eps=0.01 / batch_size)
            self.opt_state = opt_init(self.params)

        # Fraction proposal network. Building the transform and the optimizer
        # needs no sample input, so it is done unconditionally.
        self.cum_p_net = hk.without_apply_rng(hk.transform(lambda s: CumProbNetwork(num_quantiles=num_quantiles)(s)))
        opt_init_cum_p, self.opt_cum_p = optix.rmsprop(lr_cum_p, decay=0.95, eps=1e-5, centered=True)
        if setup_net:
            # `fake_feature` is only bound when the quantile network was built
            # above. It presumably serves as a sample input for shape-
            # initializing the fraction-proposal parameters. Using it
            # unconditionally would raise UnboundLocalError for setup_net=False.
            self.params_cum_p = self.cum_p_net.init(next(self.rng), fake_feature)
            self.opt_state_cum_p = opt_init_cum_p(self.params_cum_p)
Beispiel #8
0
def main(debug: bool = False, eager: bool = False):
    """Train an MLP autoencoder on MNIST, plot its history and reconstructions.

    Args:
        debug: If True, start a debugpy server on port 5678 and block until a
            debugger attaches.
        eager: Passed to ``elegy.Model`` as ``run_eagerly``.
    """

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    # Labels are not needed for an autoencoder.
    X_train, _1, X_test, _2 = dataget.image.mnist(global_cache=True).get()

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    class MLP(elegy.Module):
        """MLP autoencoder: n1 -> n2 -> n1 -> input size.

        The sigmoid output is reshaped back to the input's shape and rescaled
        by 255.
        """
        def __init__(self, n1: int = 300, n2: int = 100, **kwargs):
            super().__init__(**kwargs)
            self.n1 = n1
            self.n2 = n2

        def call(self, image: jnp.ndarray):
            # Scale to [0, 1]; assumes raw pixel values in [0, 255].
            image = image.astype(jnp.float32) / 255.0
            x = hk.Flatten()(image)
            x = hk.Sequential([
                hk.Linear(self.n1),
                jax.nn.relu,
                hk.Linear(self.n2),
                jax.nn.relu,
                hk.Linear(self.n1),
                jax.nn.relu,
                hk.Linear(x.shape[-1]),
                jax.nn.sigmoid,
            ])(x)
            return x.reshape(image.shape) * 255

    class MeanSquaredError(elegy.Loss):
        # We request `x` instead of `y_true` because autoencoders don't
        # require labels.
        def call(self, x, y_pred):
            return jnp.mean(jnp.square(x - y_pred), axis=-1)

    model = elegy.Model(
        module=MLP.defer(n1=256, n2=64),
        loss=MeanSquaredError(),
        optimizer=optix.rmsprop(0.001),
        run_eagerly=eager,
    )

    # Notice we are not passing `y`
    history = model.fit(
        x=X_train,
        epochs=20,
        batch_size=64,
        validation_data=(X_test, ),
        shuffle=True,
        callbacks=[elegy.callbacks.TensorBoard()],
    )

    plot_history(history)

    # Get random samples. The index range comes from the actual test-set
    # size rather than a hard-coded 10000.
    idxs = np.random.randint(0, X_test.shape[0], size=(5, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # Plot originals (top row) above their reconstructions (bottom row).
    plt.figure(figsize=(12, 12))
    for i in range(5):
        plt.subplot(2, 5, i + 1)
        plt.imshow(x_sample[i], cmap="gray")
        plt.subplot(2, 5, 5 + i + 1)
        plt.imshow(y_pred[i], cmap="gray")

    plt.show()