Example #1
    def __init__(self, policy: str, action_dim: int, max_action: float,
                 lr: float, discount: float, noise_clip: float,
                 policy_noise: float, policy_freq: int, actor_rng: jnp.ndarray,
                 critic_rng: jnp.ndarray, sample_state: np.ndarray):
        self.discount = discount
        self.noise_clip = noise_clip
        self.policy_noise = policy_noise
        self.policy_freq = policy_freq
        self.max_action = max_action
        self.td3_update = policy == 'TD3'

        self.actor = hk.transform(lambda x: Actor(action_dim, max_action)(x))
        actor_opt_init, self.actor_opt_update = optix.adam(lr)

        self.critic = hk.transform(lambda x: Critic()(x))
        critic_opt_init, self.critic_opt_update = optix.adam(lr)

        self.actor_params = self.target_actor_params = self.actor.init(
            actor_rng, sample_state)
        self.actor_opt_state = actor_opt_init(self.actor_params)

        action = self.actor.apply(self.actor_params, sample_state)

        self.critic_params = self.target_critic_params = self.critic.init(
            critic_rng, jnp.concatenate((sample_state, action), 0))
        self.critic_opt_state = critic_opt_init(self.critic_params)

        self.updates = 0
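
Example #1 only shows the constructor: the (init, update) pair unpacked from optix.adam is stored and consumed later in the TD3 training step. For reference, here is a minimal, self-contained sketch of that init/update/apply_updates cycle, assuming the old jax.experimental.optix module; the toy loss, parameter shapes, and data below are illustrative and are not part of the example above.

import jax
import jax.numpy as jnp
from jax.experimental import optix


def loss_fn(params, x, y):
    # Toy quadratic loss, purely for illustration.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
opt_init, opt_update = optix.adam(1e-3)   # same unpacking as in Example #1
opt_state = opt_init(params)


@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = opt_update(grads, opt_state)
    params = optix.apply_updates(params, updates)
    return params, opt_state


x, y = jnp.ones((8, 3)), jnp.zeros((8, 1))
params, opt_state = train_step(params, opt_state, x, y)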
Example #2
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> BootstrappedDqn:
    """Initialize a Bootstrapped DQN agent with default parameters."""

    # Define network.
    prior_scale = 3.
    hidden_sizes = [50, 50]

    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        """Simple Q-network with randomized prior function."""
        net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        x = hk.Flatten()(inputs)
        return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

    optimizer = optix.adam(learning_rate=1e-3)
    return BootstrappedDqn(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=hk.transform(network),
        batch_size=128,
        discount=.99,
        num_ensemble=20,
        replay_capacity=10000,
        min_replay_size=128,
        sgd_period=1,
        target_update_period=4,
        optimizer=optimizer,
        mask_prob=0.5,
        noise_scale=0.,
        epsilon_fn=lambda _: 0.,
        seed=seed,
    )
Example #3
    def __init__(
            self,
            num_critics=2,
            fn_error=None,
            lr_error=3e-4,
            units_error=(256, 256, 256),
            d2rl=False,
            init_error=10.0,
    ):
        if fn_error is None:

            def fn_error(s, a):
                return ContinuousQFunction(
                    num_critics=num_critics,
                    hidden_units=units_error,
                    d2rl=d2rl,
                )(s, a)

        # Error model.
        self.error = hk.without_apply_rng(hk.transform(fn_error))
        self.params_error = self.params_error_target = self.error.init(
            next(self.rng), *self.fake_args_critic)
        opt_init, self.opt_error = optix.adam(lr_error)
        self.opt_state_error = opt_init(self.params_error)
        # Running mean of error.
        self.rm_error_list = [
            jnp.array(init_error, dtype=jnp.float32)
            for _ in range(num_critics)
        ]
Example #4
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Creates an actor-critic agent with default hyperparameters."""

  hidden_size = 256
  initial_rnn_state = hk.LSTMState(
      hidden=jnp.zeros((1, hidden_size), dtype=jnp.float32),
      cell=jnp.zeros((1, hidden_size), dtype=jnp.float32))

  def network(inputs: jnp.ndarray,
              state) -> Tuple[Tuple[Logits, Value], LSTMState]:
    flat_inputs = hk.Flatten()(inputs)
    torso = hk.nets.MLP([hidden_size, hidden_size])
    lstm = hk.LSTM(hidden_size)
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)

    embedding = torso(flat_inputs)
    embedding, state = lstm(embedding, state)
    logits = policy_head(embedding)
    value = value_head(embedding)
    return (logits, jnp.squeeze(value, axis=-1)), state

  return ActorCriticRNN(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      initial_rnn_state=initial_rnn_state,
      optimizer=optix.adam(3e-3),
      rng=hk.PRNGSequence(seed),
      sequence_length=32,
      discount=0.99,
      td_lambda=0.9,
  )
Example #5
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Creates an actor-critic agent with default hyperparameters."""

  def network(inputs: jnp.ndarray) -> Tuple[Logits, Value]:
    flat_inputs = hk.Flatten()(inputs)
    torso = hk.nets.MLP([64, 64])
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)
    embedding = torso(flat_inputs)
    logits = policy_head(embedding)
    value = value_head(embedding)
    return logits, jnp.squeeze(value, axis=-1)

  return ActorCritic(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      optimizer=optix.adam(3e-3),
      rng=hk.PRNGSequence(seed),
      sequence_length=32,
      discount=0.99,
      td_lambda=0.9,
  )
Example #6
def main(_):
    # Create the dataset.
    train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                             FLAGS.sequence_length)
    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                  FLAGS.num_layers, FLAGS.dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optix.chain(optix.clip_by_global_norm(FLAGS.grad_clip_value),
                            optix.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
        # Using values from state/metrics too often will block the runahead and can
        # cause these overheads to become more prominent.
        if step % LOG_EVERY == 0:
            steps_per_sec = LOG_EVERY / (time.time() - prev_time)
            prev_time = time.time()
            metrics.update({'steps_per_sec': steps_per_sec})
            logging.info({k: float(v) for k, v in metrics.items()})
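
For reference, a small self-contained sketch of the optix.chain pattern used above: the listed gradient transformations compose in order, so global-norm clipping runs before Adam. This assumes the old jax.experimental.optix module; the loss, parameters, and clipping threshold are illustrative stand-ins, not values from the example.

import jax
import jax.numpy as jnp
from jax.experimental import optix


def loss(params, x):
    return jnp.sum((params["w"] * x) ** 2)


params = {"w": jnp.ones(4)}
opt = optix.chain(
    optix.clip_by_global_norm(0.25),        # stand-in for FLAGS.grad_clip_value
    optix.adam(1e-3, b1=0.9, b2=0.99),
)
opt_state = opt.init(params)

grads = jax.grad(loss)(params, jnp.arange(4.0))
updates, opt_state = opt.update(grads, opt_state)
params = optix.apply_updates(params, updates)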
Example #7
def update(state: TrainingState, batch: dataset.Batch) -> TrainingState:
    """Does a step of SGD given inputs & targets."""
    _, optimizer = optix.adam(FLAGS.learning_rate)
    _, loss_fn = hk.without_apply_rng(hk.transform(sequence_loss))
    gradients = jax.grad(loss_fn)(state.params, batch)
    updates, new_opt_state = optimizer(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)
    return TrainingState(params=new_params, opt_state=new_opt_state)
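
A note on the tuple unpacking in Example #7: optix.adam returns an (init, update) pair and hk.without_apply_rng(hk.transform(...)) returns an (init, apply) pair, so the underscores discard the two init functions. Below is a self-contained sketch of the same unpacking style on a toy Haiku module; the module and shapes are hypothetical, not taken from the example.

import haiku as hk
import jax
import jax.numpy as jnp
from jax.experimental import optix


def forward(x):
    # Minimal Haiku module used only to illustrate the unpacking.
    return hk.Linear(1)(x)


transformed = hk.without_apply_rng(hk.transform(forward))
_, apply_fn = transformed            # keep only the apply function
_, adam_update = optix.adam(1e-3)    # keep only the Adam update function

params = transformed.init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
out = apply_fn(params, jnp.ones((2, 4)))   # shape (2, 1)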
Example #8
    def __init__(self, obs_spec, action_spec, epsilon, learning_rate):
        self._obs_spec = obs_spec
        self._action_spec = action_spec
        self._epsilon = epsilon
        # Neural net and optimiser.
        self._network = build_network(action_spec.num_values)
        self._optimizer = optix.adam(learning_rate)
        # Jitting for speed.
        self.actor_step = jax.jit(self.actor_step)
        self.learner_step = jax.jit(self.learner_step)
Example #9
    def __init__(
            self,
            observation_spec,
            action_spec,
            params: RlaxRainbowParams = RlaxRainbowParams()):

        if not callable(params.epsilon):
            eps = params.epsilon
            params = params._replace(epsilon=lambda ts: eps)
        if not callable(params.beta_is):
            beta = params.beta_is
            params = params._replace(beta_is=lambda ts: beta)
        self.params = params
        self.rng = hk.PRNGSequence(jax.random.PRNGKey(params.seed))

        # Build and initialize Q-network.
        def build_network(
                layers: List[int],
                output_shape: List[int]) -> hk.Transformed:

            def q_net(obs):
                layers_ = tuple(layers) + (onp.prod(output_shape), )
                network = NoisyMLP(layers_)
                return hk.Reshape(output_shape=output_shape)(network(obs))

            return hk.transform(q_net)

        self.network = build_network(params.layers,
                                     (action_spec.num_values, params.n_atoms))
        self.trg_params = self.network.init(
            next(self.rng), observation_spec.generate_value().astype(onp.float16))
        self.online_params = self.trg_params
        self.atoms = jnp.tile(jnp.linspace(-params.atom_vmax, params.atom_vmax, params.n_atoms),
                              (action_spec.num_values, 1))

        # Build and initialize optimizer.
        self.optimizer = optix.adam(params.learning_rate, eps=3.125e-5)
        self.opt_state = self.optimizer.init(self.online_params)
        self.train_step = 0
        self.update_q = DQNLearning.update_q

        if params.use_priority:
            self.experience = PriorityBuffer(
                observation_spec.shape[1],
                action_spec.num_values,
                1,
                params.experience_buffer_size)
        else:
            self.experience = ExperienceBuffer(
                observation_spec.shape[1],
                action_spec.num_values,
                1,
                params.experience_buffer_size)
        self.last_obs = onp.empty(observation_spec.shape)
        self.requires_vectorized_observation = lambda: True
Example #10
    def __init__(self, config: DDPGConfig):
        """Initialize the algorithm."""
        pi_net = partial(mlp_policy_net,
                         output_sizes=config.pi_net_size +
                         (config.action_dim, ))
        pi_net = hk.transform(pi_net)
        q_net = partial(mlp_value_net, output_sizes=config.q_net_size + (1, ))
        q_net = hk.transform(q_net)
        pi_optimizer = optix.adam(learning_rate=config.learning_rate)
        q_optimizer = optix.adam(learning_rate=config.learning_rate)
        self.func = DDPGFunc(pi_net, q_net, pi_optimizer, q_optimizer,
                             config.gamma, config.state_dim, config.action_dim)
        rng = jax.random.PRNGKey(config.seed)  # random number generator
        state = jnp.zeros(config.state_dim)
        state_action = jnp.zeros(config.state_dim + config.action_dim)
        pi_params = pi_net.init(rng, state)
        q_params = q_net.init(rng, state_action)
        pi_opt_state = pi_optimizer.init(pi_params)
        q_opt_state = q_optimizer.init(q_params)
        self.state = DDPGState(pi_params, q_params, pi_opt_state, q_opt_state)
        return
Example #11
def main(debug: bool = False, eager: bool = False):

    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    X_train, _1, X_test, _2 = dataget.image.mnist(global_cache=True).get()
    # Now binarize data
    X_train = (X_train > 0).astype(jnp.float32)
    X_test = (X_test > 0).astype(jnp.float32)

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    model = elegy.Model(
        module=VariationalAutoEncoder.defer(),
        loss=[KLDivergence(), BinaryCrossEntropy(on="logits")],
        optimizer=optix.adam(1e-3),
        run_eagerly=eager,
    )

    epochs = 10

    # Fit with datasets in memory
    history = model.fit(
        x=X_train,
        epochs=epochs,
        batch_size=64,
        steps_per_epoch=100,
        validation_data=(X_test, ),
        shuffle=True,
    )
    plot_history(history)

    # get random samples
    idxs = np.random.randint(0, len(X_test), size=(5, ))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot results
    plt.figure(figsize=(12, 12))
    for i in range(5):
        plt.subplot(2, 5, i + 1)
        plt.imshow(x_sample[i], cmap="gray")
        plt.subplot(2, 5, 5 + i + 1)
        plt.imshow(y_pred["image"][i], cmap="gray")

    plt.show()
Example #12
    def __init__(self, obs_spec, action_spec, epsilon_cfg, target_period,
                 learning_rate):
        self._obs_spec = obs_spec
        self._action_spec = action_spec
        self._target_period = target_period
        # Neural net and optimiser.
        self._network = build_network(action_spec.num_values)
        self._optimizer = optix.adam(learning_rate)
        self._epsilon_by_frame = rlax.polynomial_schedule(**epsilon_cfg)
        # Jitting for speed.
        self.actor_step = jax.jit(self.actor_step)
        self.learner_step = jax.jit(self.learner_step)
Example #13
def run(bsuite_id: str) -> str:
    """Runs a DQN agent on a given bsuite environment, logging to CSV."""

    env = bsuite.load_and_record(
        bsuite_id=bsuite_id,
        save_path=FLAGS.save_path,
        logging_mode=FLAGS.logging_mode,
        overwrite=FLAGS.overwrite,
    )
    action_spec = env.action_spec()

    # Define network.
    prior_scale = 5.
    hidden_sizes = [50, 50]

    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        """Simple Q-network with randomized prior function."""
        net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        x = hk.Flatten()(inputs)
        return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

    optimizer = optix.adam(learning_rate=1e-3)

    agent = boot_dqn.BootstrappedDqn(
        obs_spec=env.observation_spec(),
        action_spec=action_spec,
        network=network,
        optimizer=optimizer,
        num_ensemble=FLAGS.num_ensemble,
        batch_size=128,
        discount=.99,
        replay_capacity=10000,
        min_replay_size=128,
        sgd_period=1,
        target_update_period=4,
        mask_prob=1.0,
        noise_scale=0.,
    )

    num_episodes = FLAGS.num_episodes or getattr(env, 'bsuite_num_episodes')
    experiment.run(agent=agent,
                   environment=env,
                   num_episodes=num_episodes,
                   verbose=FLAGS.verbose)

    return bsuite_id
Example #14
def main():
    net = hk.transform(behavioral_cloning)
    opt = optix.adam(0.001)

    @jax.jit
    def loss(params, batch):
        """ The loss criterion for our model """
        logits = net.apply(params, None, batch)
        return mse_loss(logits, batch[2])

    @jax.jit
    def update(opt_state, params, batch):
        grads = jax.grad(loss)(params, batch)
        updates, opt_state = opt.update(grads, opt_state)
        params = optix.apply_updates(params, updates)
        return params, opt_state

    @jax.jit
    def accuracy(params, batch):
        """ Simply report the loss for the current batch """
        logits = net.apply(params, None, batch)
        return mse_loss(logits, batch[2])

    train_dataset, val_dataset = load_data(MINERL_ENV,
                                           batch_size=32,
                                           epochs=100)

    rng = jax.random.PRNGKey(2020)
    batch = next(train_dataset)
    params = net.init(rng, batch)
    opt_state = opt.init(params)

    for i, batch in enumerate(train_dataset):
        params, opt_state = update(opt_state, params, batch)
        if i % 1000 == 0:
            print(accuracy(params, val_dataset))

        if i % 10000 == 0:
            with open(PARAMS_FILENAME, 'wb') as fh:
                dill.dump(params, fh)

    with open(PARAMS_FILENAME, 'wb') as fh:
        dill.dump(params, fh)
Example #15
def main(_):
    model = hk.transform(lambda x: VariationalAutoEncoder()(x))  # pylint: disable=unnecessary-lambda
    optimizer = optix.adam(FLAGS.learning_rate)

    @jax.jit
    def loss_fn(params: hk.Params, rng_key: PRNGKey,
                batch: Batch) -> jnp.ndarray:
        """ELBO loss: E_p[log(x)] - KL(d||q), where p ~ Be(0.5) and q ~ N(0,1)."""
        outputs: VAEOutput = model.apply(params, rng_key, batch["image"])

        log_likelihood = -binary_cross_entropy(batch["image"], outputs.logits)
        kl = kl_gaussian(outputs.mean, outputs.stddev**2)
        elbo = log_likelihood - kl

        return -jnp.mean(elbo)

    @jax.jit
    def update(
        params: hk.Params,
        rng_key: PRNGKey,
        opt_state: OptState,
        batch: Batch,
    ) -> Tuple[hk.Params, OptState]:
        """Single SGD update step."""
        grads = jax.grad(loss_fn)(params, rng_key, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optix.apply_updates(params, updates)
        return new_params, new_opt_state

    rng_seq = hk.PRNGSequence(42)
    params = model.init(next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)))
    opt_state = optimizer.init(params)

    train_ds = load_dataset(tfds.Split.TRAIN, FLAGS.batch_size)
    valid_ds = load_dataset(tfds.Split.TEST, FLAGS.batch_size)

    for step in range(FLAGS.training_steps):
        params, opt_state = update(params, next(rng_seq), opt_state,
                                   next(train_ds))

        if step % FLAGS.eval_frequency == 0:
            val_loss = loss_fn(params, next(rng_seq), next(valid_ds))
            logging.info("STEP: %5d; Validation ELBO: %.3f", step, -val_loss)
Example #16
    def __init__(
            self,
            observation_spec,
            action_spec,
            target_update_period: int = None,
            discount: float = None,
            epsilon=lambda x: 0.1,
            learning_rate: float = 0.001,
            layers: List[int] = None,
            use_double_q=True,
            use_priority=True,
            seed: int = 1234,
            importance_beta: float = 0.4):
        self.rng = hk.PRNGSequence(jax.random.PRNGKey(seed))

        # Build and initialize Q-network.
        self.layers = layers or (512,)
        self.network = build_network(self.layers, action_spec.num_values)
        #  sample_input = env.observation_spec()["observation"].generate_value()
        #  sample_input = jnp.zeros(observation_le)
        self.trg_params = self.network.init(
            next(self.rng), observation_spec.generate_value().astype(onp.float16))
        self.online_params = self.trg_params  # self.network.init(next(self.rng), sample_input)

        # Build and initialize optimizer.
        self.optimizer = optix.adam(learning_rate)
        self.opt_state = self.optimizer.init(self.online_params)
        self.epsilon = epsilon
        self.train_steps = 0
        self.target_update_period = target_update_period or 500
        self.discount = discount or 0.99
        #  if use_double_q:
        #      self.update_q = DQNLearning.update_q_double
        #  else:
        #      self.update_q = DQNLearning.update_q
        self.update_q = DQNLearning.update_q

        if use_priority:
            self.experience = PriorityBuffer(
                observation_spec.shape[1], action_spec.num_values, 1, 2**19)
        else:
            self.experience = ExperienceBuffer(
                observation_spec.shape[1], action_spec.num_values, 1, 2**19)
        self.importance_beta = importance_beta
        self.last_obs = onp.empty(observation_spec.shape)
Example #17
    def test_adam(self):
        b1, b2, eps = 0.9, 0.999, 1e-8

        # experimental/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = optimizers.adam(LR, b1, b2, eps)
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
            jax_params = get_params(state)

        # experimental/optix.py
        optix_params = self.init_params
        adam = optix.adam(LR, b1, b2, eps)
        state = adam.init(optix_params)
        for _ in range(STEPS):
            updates, state = adam.update(self.per_step_updates, state)
            optix_params = optix.apply_updates(optix_params, updates)

        # Check equivalence.
        for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)):
            np.testing.assert_allclose(x, y, rtol=1e-4)
Example #18
    def test_graph_network_learning(self, spatial_dimension, dtype):
        key = random.PRNGKey(0)

        R_key, dr0_key, params_key = random.split(key, 3)

        d, _ = space.free()

        R = random.uniform(R_key, (6, 3, spatial_dimension), dtype=dtype)
        dr0 = random.uniform(dr0_key, (6, 3, 3), dtype=dtype)
        E_gt = vmap(lambda R, dr0: np.sum(
            (space.distance(space.map_product(d)(R, R)) - dr0) ** 2))

        cutoff = 0.2

        init_fn, energy_fn = energy.graph_network(d, cutoff)
        params = init_fn(params_key, R[0])

        @jit
        def loss(params, R):
            return np.mean((vmap(energy_fn,
                                 (None, 0))(params, R) - E_gt(R, dr0))**2)

        opt = optix.chain(optix.clip_by_global_norm(1.0), optix.adam(1e-4))

        @jit
        def update(params, opt_state, R):
            updates, opt_state = opt.update(grad(loss)(params, R), opt_state)
            return optix.apply_updates(params, updates), opt_state

        opt_state = opt.init(params)

        l0 = loss(params, R)
        for i in range(4):
            params, opt_state = update(params, opt_state, R)

        assert loss(params, R) < l0 * 0.95
Example #19
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
    """Initialize a DQN agent with default parameters."""
    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        flat_inputs = hk.Flatten()(inputs)
        mlp = hk.nets.MLP([64, 64, action_spec.num_values])
        action_values = mlp(flat_inputs)
        return action_values

    return DQN(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=hk.transform(network),
        optimizer=optix.adam(1e-3),
        batch_size=32,
        discount=0.99,
        replay_capacity=10000,
        min_replay_size=100,
        sgd_period=1,
        target_update_period=4,
        epsilon=0.05,
        rng=hk.PRNGSequence(seed),
    )
Example #20
def make_optimizer() -> optix.InitUpdate:
    """Defines the optimizer."""
    return optix.adam(FLAGS.learning_rate)
Example #21
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.hidden_layer_sizes is None:
        # Cannot pass default arguments as lists due to style requirements, so we
        # override it here if they are not set.
        FLAGS.hidden_layer_sizes = DEFAULT_LAYER_SIZES

    # Make the network.
    net = hk.transform(net_fn)

    # Make the optimiser.
    opt = optix.adam(FLAGS.step_size)

    @jax.jit
    def loss(
        params: Params,
        inputs: np.ndarray,
        targets: np.ndarray,
    ) -> jnp.DeviceArray:
        """Cross-entropy loss."""
        assert targets.dtype == np.int32
        log_probs = net.apply(params, inputs)
        return -jnp.mean(one_hot(targets, NUM_ACTIONS) * log_probs)

    @jax.jit
    def accuracy(
        params: Params,
        inputs: np.ndarray,
        targets: np.ndarray,
    ) -> jnp.DeviceArray:
        """Classification accuracy."""
        predictions = net.apply(params, inputs)
        return jnp.mean(jnp.argmax(predictions, axis=-1) == targets)

    @jax.jit
    def update(
        params: Params,
        opt_state: OptState,
        inputs: np.ndarray,
        targets: np.ndarray,
    ) -> Tuple[Params, OptState]:
        """Learning rule (stochastic gradient descent)."""
        _, gradient = jax.value_and_grad(loss)(params, inputs, targets)
        updates, opt_state = opt.update(gradient, opt_state)
        new_params = optix.apply_updates(params, updates)
        return new_params, opt_state

    def output_samples(params: Params, max_samples: int):
        """Output some cases where the policy disagrees with the dataset action."""
        if max_samples == 0:
            return
        count = 0
        with open(os.path.join(FLAGS.data_path, 'test.txt')) as f:
            lines = list(f)
        np.random.shuffle(lines)
        for line in lines:
            state = GAME.new_initial_state()
            actions = _trajectory(line)
            for action in actions:
                if not state.is_chance_node():
                    observation = np.array(state.information_state_tensor(),
                                           np.float32)
                    policy = np.exp(net.apply(params, observation))
                    probs_actions = [(p, a) for a, p in enumerate(policy)]
                    pred = max(probs_actions)[1]
                    if pred != action:
                        print(state)
                        for p, a in reversed(
                                sorted(probs_actions)[-TOP_K_ACTIONS:]):
                            print('{:7} {:.2f}'.format(
                                state.action_to_string(a), p))
                        print('Ground truth {}\n'.format(
                            state.action_to_string(action)))
                        count += 1
                        break
                state.apply_action(action)
            if count >= max_samples:
                return

    # Store what we need to rebuild the Haiku net.
    if FLAGS.save_path:
        filename = os.path.join(FLAGS.save_path, 'layers.txt')
        with open(filename, 'w') as layer_def_file:
            for s in FLAGS.hidden_layer_sizes:
                layer_def_file.write(f'{s} ')
            layer_def_file.write('\n')

    # Make datasets.
    if FLAGS.data_path is None:
        raise app.UsageError(
            'Please generate your own supervised training data and supply the '
            'local location as --data_path')
    train = batch(make_dataset(os.path.join(FLAGS.data_path, 'train.txt')),
                  FLAGS.train_batch)
    test = batch(make_dataset(os.path.join(FLAGS.data_path, 'test.txt')),
                 FLAGS.eval_batch)

    # Initialize network and optimiser.
    if FLAGS.checkpoint_file:
        with open(FLAGS.checkpoint_file, 'rb') as pkl_file:
            params, opt_state = pickle.load(pkl_file)
    else:
        rng = jax.random.PRNGKey(
            FLAGS.rng_seed)  # seed used for network weights
        inputs, unused_targets = next(train)
        params = net.init(rng, inputs)
        opt_state = opt.init(params)

    # Train/eval loop.
    for step in range(FLAGS.iterations):
        # Do SGD on a batch of training examples.
        inputs, targets = next(train)
        params, opt_state = update(params, opt_state, inputs, targets)

        # Periodically evaluate classification accuracy on the test set.
        if (1 + step) % FLAGS.eval_every == 0:
            inputs, targets = next(test)
            test_accuracy = accuracy(params, inputs, targets)
            print(f'After {1+step} steps, test accuracy: {test_accuracy}.')
            if FLAGS.save_path:
                filename = os.path.join(FLAGS.save_path,
                                        f'checkpoint-{1 + step}.pkl')
                with open(filename, 'wb') as pkl_file:
                    pickle.dump((params, opt_state), pkl_file)
            output_samples(params, FLAGS.num_examples)
Example #22
    def __init__(
        self,
        num_agent_steps,
        state_space,
        action_space,
        seed,
        max_grad_norm=None,
        gamma=0.99,
        nstep=1,
        num_critics=1,
        buffer_size=10**6,
        use_per=False,
        batch_size=256,
        start_steps=10000,
        update_interval=1,
        tau=5e-3,
        fn_actor=None,
        fn_critic=None,
        lr_actor=1e-3,
        lr_critic=1e-3,
        units_actor=(256, 256),
        units_critic=(256, 256),
        d2rl=False,
        std=0.1,
        update_interval_policy=2,
    ):
        super(DDPG, self).__init__(
            num_agent_steps=num_agent_steps,
            state_space=state_space,
            action_space=action_space,
            seed=seed,
            max_grad_norm=max_grad_norm,
            gamma=gamma,
            nstep=nstep,
            num_critics=num_critics,
            buffer_size=buffer_size,
            use_per=use_per,
            batch_size=batch_size,
            start_steps=start_steps,
            update_interval=update_interval,
            tau=tau,
        )
        if d2rl:
            self.name += "-D2RL"

        if fn_critic is None:

            def fn_critic(s, a):
                return ContinuousQFunction(
                    num_critics=num_critics,
                    hidden_units=units_critic,
                    d2rl=d2rl,
                )(s, a)

        if fn_actor is None:

            def fn_actor(s):
                return DeterministicPolicy(
                    action_space=action_space,
                    hidden_units=units_actor,
                    d2rl=d2rl,
                )(s)

        # Critic.
        self.critic = hk.without_apply_rng(hk.transform(fn_critic))
        self.params_critic = self.params_critic_target = self.critic.init(
            next(self.rng), *self.fake_args_critic)
        opt_init, self.opt_critic = optix.adam(lr_critic)
        self.opt_state_critic = opt_init(self.params_critic)

        # Actor.
        self.actor = hk.without_apply_rng(hk.transform(fn_actor))
        self.params_actor = self.params_actor_target = self.actor.init(
            next(self.rng), *self.fake_args_actor)
        opt_init, self.opt_actor = optix.adam(lr_actor)
        self.opt_state_actor = opt_init(self.params_actor)

        # Other parameters.
        self.std = std
        self.update_interval_policy = update_interval_policy
Example #23
def train(batch_size, max_iter, num_test_samples):
    model = make_network(4, 8, output_dim=2)
    optimizer = optix.adam(0.005)

    data_generator = collect_batch(
        dataset.FamilyTreeDataset(
            task="grandparents",
            epoch_size=10,
            nmin=20),
        n=batch_size,
    )

    @jax.jit
    def evaluate_batch(params, data):
        logits = model.apply(params, data.predicates)
        correct = jnp.all(data.targets == jnp.argmax(logits, axis=-1), axis=-1)
        ce_loss = jnp.sum(cross_entropy(data.targets, logits), axis=-1)
        return correct, ce_loss

    def evaluate(params):
        eval_data_gen = collect_batch(
            dataset.FamilyTreeDataset(
                task="grandparents",
                epoch_size=10,
                nmin=50),
            n=batch_size,
        )
        correct, ce_loss = zip(*[
            evaluate_batch(params, preprocess_batch(data))
            for data in itertools.islice(eval_data_gen, num_test_samples)
        ])
        correct = np.concatenate(correct, axis=0)
        ce_loss = np.concatenate(ce_loss, axis=0)

        return np.mean(correct), np.mean(ce_loss)

    def loss(params, data):
        logits = model.apply(params, data.predicates)
        return jnp.mean(cross_entropy(data.targets, logits))

    @jax.jit
    def update(data, params, opt_state):
        batch_loss, dparam = jax.value_and_grad(loss)(params, data)

        updates, opt_state = optimizer.update(dparam, opt_state)
        params = optix.apply_updates(params, updates)
        return batch_loss, OptimizationState(params, opt_state)

    rng = jax.random.PRNGKey(0)
    init_params = model.init(
        rng,
        preprocess_batch(next(data_generator)).predicates,
    )
    state = OptimizationState(init_params, optimizer.init(init_params))

    for i, batch in enumerate(data_generator):
        if i >= max_iter:
            break

        batch = preprocess_batch(batch)
        batch_loss, state = update(batch, *state)

        if (i + 1) % 100 == 0:
            print(i + 1, evaluate(state.params))
Example #24
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.PolicyValueRNN,
        initial_state_fn: Callable[[], networks.RNNState],
        sequence_length: int,
        sequence_period: int,
        counter: counting.Counter = None,
        logger: loggers.Logger = None,
        discount: float = 0.99,
        max_queue_size: int = 100000,
        batch_size: int = 16,
        learning_rate: float = 1e-3,
        entropy_cost: float = 0.01,
        baseline_cost: float = 0.5,
        seed: int = 0,
        max_abs_reward: float = np.inf,
        max_gradient_norm: float = np.inf,
    ):

        num_actions = environment_spec.actions.num_values
        self._logger = logger or loggers.TerminalLogger('agent')
        queue = reverb.Table.queue(name=adders.DEFAULT_PRIORITY_TABLE,
                                   max_size=max_queue_size)
        self._server = reverb.Server([queue], port=None)
        self._can_sample = lambda: queue.can_sample(batch_size)
        address = f'localhost:{self._server.port}'

        # Component to add things into replay.
        adder = adders.SequenceAdder(
            client=reverb.Client(address),
            period=sequence_period,
            sequence_length=sequence_length,
        )

        # The dataset object to learn from.
        extra_spec = {
            'core_state': hk.transform(initial_state_fn).apply(None),
            'logits': np.ones(shape=(num_actions, ), dtype=np.float32)
        }
        # Remove batch dimensions.
        dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            extra_spec=extra_spec,
            sequence_length=sequence_length)

        rng = hk.PRNGSequence(seed)

        optimizer = optix.chain(
            optix.clip_by_global_norm(max_gradient_norm),
            optix.adam(learning_rate),
        )
        self._learner = learning.IMPALALearner(
            obs_spec=environment_spec.observations,
            network=network,
            initial_state_fn=initial_state_fn,
            iterator=dataset.as_numpy_iterator(),
            rng=rng,
            counter=counter,
            logger=logger,
            optimizer=optimizer,
            discount=discount,
            entropy_cost=entropy_cost,
            baseline_cost=baseline_cost,
            max_abs_reward=max_abs_reward,
        )

        variable_client = jax_variable_utils.VariableClient(self._learner,
                                                            key='policy')
        self._actor = acting.IMPALAActor(
            network=network,
            initial_state_fn=initial_state_fn,
            rng=rng,
            adder=adder,
            variable_client=variable_client,
        )
Example #25
def main(_):
    # Make the network and optimiser.
    net = hk.transform(net_fn)
    opt = optix.adam(1e-3)

    # Training loss (cross-entropy).
    def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
        """Compute the loss of the network, including L2."""
        logits = net.apply(params, batch)
        labels = jax.nn.one_hot(batch["label"], 10)

        l2_loss = 0.5 * sum(
            jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
        softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
        softmax_xent /= labels.shape[0]

        return softmax_xent + 1e-4 * l2_loss

    # Evaluation metric (classification accuracy).
    @jax.jit
    def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
        predictions = net.apply(params, batch)
        return jnp.mean(jnp.argmax(predictions, axis=-1) == batch["label"])

    @jax.jit
    def update(params: hk.Params, opt_state: OptState,
               batch: Batch) -> Tuple[hk.Params, OptState]:
        """Learning rule (stochastic gradient descent)."""
        grads = jax.grad(loss)(params, batch)
        updates, opt_state = opt.update(grads, opt_state)
        new_params = optix.apply_updates(params, updates)
        return new_params, opt_state

    # We maintain avg_params, the exponential moving average of the "live" params.
    # avg_params is used only for evaluation.
    # For more, see: https://doi.org/10.1137/0330046
    @jax.jit
    def ema_update(avg_params: hk.Params,
                   new_params: hk.Params,
                   epsilon: float = 0.001) -> hk.Params:
        return jax.tree_multimap(
            lambda p1, p2: (1 - epsilon) * p1 + epsilon * p2, avg_params,
            new_params)

    # Make datasets.
    train = load_dataset("train", is_training=True, batch_size=1000)
    train_eval = load_dataset("train", is_training=False, batch_size=10000)
    test_eval = load_dataset("test", is_training=False, batch_size=10000)

    # Initialize network and optimiser; note we draw an input to get shapes.
    params = avg_params = net.init(jax.random.PRNGKey(42), next(train))
    opt_state = opt.init(params)

    # Train/eval loop.
    for step in range(10001):
        if step % 1000 == 0:
            # Periodically evaluate classification accuracy on train & test sets.
            train_accuracy = accuracy(avg_params, next(train_eval))
            test_accuracy = accuracy(avg_params, next(test_eval))
            train_accuracy, test_accuracy = jax.device_get(
                (train_accuracy, test_accuracy))
            print(f"[Step {step}] Train / Test accuracy: "
                  f"{train_accuracy:.3f} / {test_accuracy:.3f}.")

        # Do SGD on a batch of training examples.
        params, opt_state = update(params, opt_state, next(train))
        avg_params = ema_update(avg_params, params)
Example #26
    def __init__(
        self,
        num_agent_steps,
        state_space,
        action_space,
        seed,
        max_grad_norm=None,
        gamma=0.99,
        nstep=1,
        buffer_size=10 ** 6,
        use_per=False,
        batch_size=32,
        start_steps=50000,
        update_interval=4,
        update_interval_target=8000,
        eps=0.01,
        eps_eval=0.001,
        eps_decay_steps=250000,
        loss_type="huber",
        dueling_net=False,
        double_q=False,
        setup_net=True,
        fn=None,
        lr=5e-5,
        lr_cum_p=2.5e-9,
        units=(512,),
        num_quantiles=32,
        num_cosines=64,
    ):
        super(FQF, self).__init__(
            num_agent_steps=num_agent_steps,
            state_space=state_space,
            action_space=action_space,
            seed=seed,
            max_grad_norm=max_grad_norm,
            gamma=gamma,
            nstep=nstep,
            buffer_size=buffer_size,
            batch_size=batch_size,
            use_per=use_per,
            start_steps=start_steps,
            update_interval=update_interval,
            update_interval_target=update_interval_target,
            eps=eps,
            eps_eval=eps_eval,
            eps_decay_steps=eps_decay_steps,
            loss_type=loss_type,
            dueling_net=dueling_net,
            double_q=double_q,
            setup_net=False,
            num_quantiles=num_quantiles,
        )
        if setup_net:
            if fn is None:

                def fn(s, cum_p):
                    return DiscreteImplicitQuantileFunction(
                        num_cosines=num_cosines,
                        action_space=action_space,
                        hidden_units=units,
                        dueling_net=dueling_net,
                    )(s, cum_p)

            self.net, self.params, fake_feature = make_quantile_nerwork(
                self.rng, state_space, action_space, fn, num_quantiles)
            self.params_target = self.params
            opt_init, self.opt = optix.adam(lr, eps=0.01 / batch_size)
            self.opt_state = opt_init(self.params)

        # Fraction proposal network.
        self.cum_p_net = hk.without_apply_rng(
            hk.transform(lambda s: CumProbNetwork(num_quantiles=num_quantiles)(s)))
        self.params_cum_p = self.cum_p_net.init(next(self.rng), fake_feature)
        opt_init, self.opt_cum_p = optix.rmsprop(
            lr_cum_p, decay=0.95, eps=1e-5, centered=True)
        self.opt_state_cum_p = opt_init(self.params_cum_p)
Example #27
    def __init__(
        self,
        num_agent_steps,
        state_space,
        action_space,
        seed,
        max_grad_norm=None,
        gamma=0.99,
        nstep=1,
        buffer_size=10 ** 6,
        use_per=False,
        batch_size=32,
        start_steps=50000,
        update_interval=4,
        update_interval_target=8000,
        eps=0.01,
        eps_eval=0.001,
        eps_decay_steps=250000,
        loss_type="huber",
        dueling_net=False,
        double_q=False,
        setup_net=True,
        fn=None,
        lr=2.5e-4,
        units=(512,),
    ):
        super(DQN, self).__init__(
            num_agent_steps=num_agent_steps,
            state_space=state_space,
            action_space=action_space,
            seed=seed,
            max_grad_norm=max_grad_norm,
            gamma=gamma,
            nstep=nstep,
            buffer_size=buffer_size,
            batch_size=batch_size,
            use_per=use_per,
            start_steps=start_steps,
            update_interval=update_interval,
            update_interval_target=update_interval_target,
            eps=eps,
            eps_eval=eps_eval,
            eps_decay_steps=eps_decay_steps,
            loss_type=loss_type,
            dueling_net=dueling_net,
            double_q=double_q,
        )
        if setup_net:
            if fn is None:

                def fn(s):
                    return DiscreteQFunction(
                        action_space=action_space,
                        hidden_units=units,
                        dueling_net=dueling_net,
                    )(s)

            self.net = hk.without_apply_rng(hk.transform(fn))
            self.params = self.params_target = self.net.init(next(self.rng), *self.fake_args)
            opt_init, self.opt = optix.adam(lr, eps=0.01 / batch_size)
            self.opt_state = opt_init(self.params)
Example #28
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks.QNetwork,
        batch_size: int = 256,
        prefetch_size: int = 4,
        target_update_period: int = 100,
        samples_per_insert: float = 32.0,
        min_replay_size: int = 1000,
        max_replay_size: int = 1000000,
        importance_sampling_exponent: float = 0.2,
        priority_exponent: float = 0.6,
        n_step: int = 5,
        epsilon: float = 0.,
        learning_rate: float = 1e-3,
        discount: float = 0.99,
    ):
        """Initialize the agent."""

        # Create a replay server to add data to. This uses no limiter behavior in
        # order to allow the Agent interface to handle it.
        replay_table = reverb.Table(
            name=adders.DEFAULT_PRIORITY_TABLE,
            sampler=reverb.selectors.Prioritized(priority_exponent),
            remover=reverb.selectors.Fifo(),
            max_size=max_replay_size,
            rate_limiter=reverb.rate_limiters.MinSize(1))
        self._server = reverb.Server([replay_table], port=None)

        # The adder is used to insert observations into replay.
        address = f'localhost:{self._server.port}'
        adder = adders.NStepTransitionAdder(client=reverb.Client(address),
                                            n_step=n_step,
                                            discount=discount)

        # The dataset provides an interface to sample from replay.
        dataset = datasets.make_reverb_dataset(
            client=reverb.TFClient(address),
            environment_spec=environment_spec,
            batch_size=batch_size,
            prefetch_size=prefetch_size,
            transition_adder=True)

        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = hk.transform(network).apply(params, observation)
            return rlax.epsilon_greedy(epsilon).sample(key, action_values)

        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(1),
            optimizer=optix.adam(learning_rate),
            discount=discount,
            importance_sampling_exponent=importance_sampling_exponent,
            target_update_period=target_update_period,
            iterator=dataset.as_numpy_iterator(),
            replay_client=reverb.Client(address),
        )

        variable_client = variable_utils.VariableClient(learner, 'foo')

        actor = actors.FeedForwardActor(policy=policy,
                                        rng=hk.PRNGSequence(1),
                                        variable_client=variable_client,
                                        adder=adder)

        super().__init__(actor=actor,
                         learner=learner,
                         min_observations=max(batch_size, min_replay_size),
                         observations_per_step=float(batch_size) /
                         samples_per_insert)
Example #29
def main_loop(unused_arg):
    env = catch.Catch(seed=FLAGS.seed)
    rng = hk.PRNGSequence(jax.random.PRNGKey(FLAGS.seed))

    # Build and initialize Q-network.
    num_actions = env.action_spec().num_values
    network = build_network(num_actions)
    sample_input = env.observation_spec().generate_value()
    net_params = network.init(next(rng), sample_input)

    # Build and initialize optimizer.
    optimizer = optix.adam(FLAGS.learning_rate)
    opt_state = optimizer.init(net_params)

    @jax.jit
    def policy(net_params, key, obs):
        """Sample action from epsilon-greedy policy."""
        q = network.apply(net_params, obs)
        a = rlax.epsilon_greedy(epsilon=FLAGS.epsilon).sample(key, q)
        return q, a

    @jax.jit
    def eval_policy(net_params, key, obs):
        """Sample action from greedy policy."""
        q = network.apply(net_params, obs)
        return rlax.greedy().sample(key, q)

    @jax.jit
    def update(net_params, opt_state, obs_tm1, a_tm1, r_t, discount_t, q_t):
        """Update network weights wrt Q-learning loss."""
        def q_learning_loss(net_params, obs_tm1, a_tm1, r_t, discount_t, q_t):
            q_tm1 = network.apply(net_params, obs_tm1)
            td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
            return rlax.l2_loss(td_error)

        dloss_dtheta = jax.grad(q_learning_loss)(net_params, obs_tm1, a_tm1,
                                                 r_t, discount_t, q_t)
        updates, opt_state = optimizer.update(dloss_dtheta, opt_state)
        net_params = optix.apply_updates(net_params, updates)
        return net_params, opt_state

    print(f"Training agent for {FLAGS.train_episodes} episodes")
    print("Returns range [-1.0, 1.0]")
    for episode in range(FLAGS.train_episodes):
        timestep = env.reset()
        obs_tm1 = timestep.observation

        _, a_tm1 = policy(net_params, next(rng), obs_tm1)

        while not timestep.last():
            new_timestep = env.step(int(a_tm1))
            obs_t = new_timestep.observation

            # Sample action from stochastic policy.
            q_t, a_t = policy(net_params, next(rng), obs_t)

            # Update Q-values.
            r_t = new_timestep.reward
            discount_t = FLAGS.discount_factor * new_timestep.discount
            net_params, opt_state = update(net_params, opt_state, obs_tm1,
                                           a_tm1, r_t, discount_t, q_t)

            timestep = new_timestep
            obs_tm1 = obs_t
            a_tm1 = a_t

        if not episode % FLAGS.evaluate_every:
            # Evaluate agent with deterministic policy.
            returns = 0.
            for _ in range(FLAGS.eval_episodes):
                timestep = env.reset()
                obs = timestep.observation

                while not timestep.last():
                    action = eval_policy(net_params, next(rng), obs)
                    timestep = env.step(int(action))
                    obs = timestep.observation
                    returns += timestep.reward

            avg_returns = returns / FLAGS.eval_episodes
            print(f"Episode {episode:4d}: Average returns: {avg_returns:.2f}")
Example #30
def main(_):
    # Make the network and optimiser.
    net = hk.transform(net_fn)
    opt = optix.adam(1e-3)

    # Define layerwise sparsities
    def module_matching(s):
        def match_func(m, n, k):
            return m.endswith(s) and not sparsity_ignore(m, n, k)

        return match_func

    module_sparsity = ((module_matching("linear"), 0.98),
                       (module_matching("linear_1"), 0.9))

    # Training loss (cross-entropy).
    @jax.jit
    def loss(params: hk.Params, batch: Batch) -> jnp.ndarray:
        """Compute the loss of the network, including L2."""
        logits = net.apply(params, batch)
        labels = jax.nn.one_hot(batch["label"], 10)

        l2_loss = 0.5 * sum(
            jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
        softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
        softmax_xent /= labels.shape[0]

        return softmax_xent + 1e-4 * l2_loss

    # Evaluation metric (classification accuracy).
    @jax.jit
    def accuracy(params: hk.Params, batch: Batch) -> jnp.ndarray:
        predictions = net.apply(params, batch)
        return jnp.mean(jnp.argmax(predictions, axis=-1) == batch["label"])

    @jax.jit
    def get_updates(
        params: hk.Params,
        opt_state: OptState,
        batch: Batch,
    ) -> Tuple[hk.Params, OptState]:
        """Learning rule (stochastic gradient descent)."""
        grads = jax.grad(loss)(params, batch)
        updates, opt_state = opt.update(grads, opt_state)
        return updates, opt_state

    # We maintain avg_params, the exponential moving average of the "live" params.
    # avg_params is used only for evaluation.
    # For more, see: https://doi.org/10.1137/0330046
    @jax.jit
    def ema_update(
        avg_params: hk.Params,
        new_params: hk.Params,
        epsilon: float = 0.001,
    ) -> hk.Params:
        return jax.tree_multimap(
            lambda p1, p2: (1 - epsilon) * p1 + epsilon * p2, avg_params,
            new_params)

    # Make datasets.
    train = load_dataset("train", is_training=True, batch_size=1000)
    train_eval = load_dataset("train", is_training=False, batch_size=10000)
    test_eval = load_dataset("test", is_training=False, batch_size=10000)

    # Implementation note: it is possible to avoid pruned_params and use a
    # single params pytree that is progressively pruned; the updates would not
    # need to be masked in that case. The current implementation mimics the TF
    # implementation, which allows previously deactivated connections to become
    # active again if currently active weights fall below them in magnitude.

    # Initialize network and optimiser; note we draw an input to get shapes.
    pruned_params = params = avg_params = net.init(jax.random.PRNGKey(42),
                                                   next(train))

    masks = update_mask(params, 0., module_sparsity)
    opt_state = opt.init(params)

    # Train/eval loop.
    for step in range(10001):
        if step % 1000 == 0:
            # Periodically evaluate classification accuracy on train & test sets.
            avg_params = apply_mask(avg_params, masks, module_sparsity)
            train_accuracy = accuracy(avg_params, next(train_eval))
            test_accuracy = accuracy(avg_params, next(test_eval))
            total_params, total_nnz, per_layer_sparsities = get_sparsity(
                avg_params)
            train_accuracy, test_accuracy, total_nnz, per_layer_sparsities = (
                jax.device_get((train_accuracy, test_accuracy, total_nnz,
                                per_layer_sparsities)))
            print(f"[Step {step}] Train / Test accuracy: "
                  f"{train_accuracy:.3f} / {test_accuracy:.3f}.")
            print(f"Non-zero params / Total: {total_nnz} / {total_params}; "
                  f"Total Sparsity: {1. - total_nnz / total_params:.3f}")

        # Do SGD on a batch of training examples.
        pruned_params = apply_mask(params, masks, module_sparsity)
        updates, opt_state = get_updates(pruned_params, opt_state, next(train))
        # Applying a straight-through estimator here (i.e. not masking the
        # updates) leads to much worse performance.
        updates = apply_mask(updates, masks, module_sparsity)
        params = optix.apply_updates(params, updates)
        # Pruning ramps up from step 1000 over the following 8000 steps.
        progress = min(max((step - 1000.) / 8000., 0.), 1.)
        if step % 200 == 0:
            sparsity_fraction = zhugupta_func(progress)
            masks = update_mask(params, sparsity_fraction, module_sparsity)
        avg_params = ema_update(avg_params, params)
    print(per_layer_sparsities)
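
The pruning loop above calls a zhugupta_func helper that is not shown. A plausible reconstruction, assuming the cubic gradual-pruning schedule of Zhu & Gupta (2017) that the name suggests; this is a sketch, not the example's actual definition.

def zhugupta_func(progress: float) -> float:
    # Maps pruning progress in [0, 1] to the fraction of the target sparsity
    # to apply, ramping up quickly at first and flattening out near the end.
    return 1.0 - (1.0 - progress) ** 3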