Code Example #1
def create_logistic_model(
    only_digits: bool = False,
    reg_fn: Optional[Callable[[core.Params],
                              jnp.ndarray]] = None) -> core.Model:
  """Creates EMNIST logistic model."""
  num_classes = 10 if only_digits else 62

  def forward_pass(batch):
    network = hk.Sequential([
        hk.Flatten(),
        hk.Linear(num_classes),
    ])
    return network(batch['x'])

  transformed_forward_pass = hk.transform(forward_pass)
  return core.create_model_from_haiku(
      transformed_forward_pass=transformed_forward_pass,
      sample_batch=_EMNIST_HAIKU_SAMPLE_BATCH,
      loss_fn=_EMNIST_LOSS_FN,
      reg_fn=reg_fn,
      metrics_fn_map=_EMNIST_METRICS_FN_MAP)
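For context, here is a minimal sketch of the init/apply pattern that hk.transform produces; the forward function and shapes below are illustrative and not taken from fedjax.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Flatten the batch and apply a single linear layer, as in the model above.
    return hk.Linear(10)(hk.Flatten()(x))

transformed = hk.transform(forward)
rng = jax.random.PRNGKey(0)
sample = jnp.ones((1, 28, 28, 1))                 # dummy batch for shape inference
params = transformed.init(rng, sample)            # create parameters
logits = transformed.apply(params, rng, sample)   # pure forward pass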
Code Example #2
File: ibp_test.py  Project: zeta1999/jax_verify
    def test_linear_ibp(self):
        def linear_model(inp):
            return hk.Linear(1)(inp)

        z = jnp.array([[1., 2., 3.]])
        params = {
            'linear': {
                'w': jnp.ones((3, 1), dtype=jnp.float32),
                'b': jnp.array([2.])
            }
        }

        fun = functools.partial(
            hk.without_apply_rng(hk.transform(linear_model)).apply, params)
        input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
        output_bounds = jax_verify.interval_bound_propagation(
            fun, input_bounds)

        self.assertAlmostEqual(5., output_bounds.lower)
        self.assertAlmostEqual(11., output_bounds.upper)
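As a sanity check on the asserted bounds: with the weights fixed to 1 and the bias to 2, interval arithmetic over the input box [0, 1, 2] to [2, 3, 4] gives a lower bound of 0 + 1 + 2 + 2 = 5 and an upper bound of 2 + 3 + 4 + 2 = 11, matching the assertions above.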
Code Example #3
File: transformed_test.py  Project: deepmind/distrax
  def test_bijector_that_assumes_batch_dimensions(self):
    # Create a Haiku conditioner that assumes a single batch dimension.
    def forward(x):
      network = hk.Sequential([hk.Flatten(preserve_dims=1), hk.Linear(3)])
      return network(x)
    init, apply = hk.transform(forward)
    params = init(self.seed, jnp.ones((2, 3)))
    conditioner = functools.partial(apply, params, self.seed)

    bijector = masked_coupling.MaskedCoupling(
        jnp.ones(3) > 0, conditioner, tfb.Scale)

    base = tfd.MultivariateNormalDiag(jnp.zeros((2, 3)), jnp.ones((2, 3)))
    dist = transformed.Transformed(base, bijector)
    # Exercise the trace-based functions
    assert dist.batch_shape == (2,)
    assert dist.event_shape == (3,)
    assert dist.dtype == jnp.float32
    sample = self.variant(dist.sample)(seed=self.seed)
    assert sample.dtype == dist.dtype
    self.variant(dist.log_prob)(sample)
Code Example #4
def test_torch_to_jax():
    x = np.random.randn(250, 66).astype(np.float32)

    net_torch = FlexibleNeRFModelTorch()
    net_jax = hk.without_apply_rng(
        hk.transform(jax.jit(lambda x: FlexibleNeRFModel()(x))))

    jax_params = torch_to_jax(dict(net_torch.named_parameters()),
                              "flexible_ne_rf_model")

    jax_out = net_jax.apply(jax_params, jnp.array(x))
    torch_out = net_torch(torch.from_numpy(x))

    assert np.allclose(torch_out.detach().numpy(),
                       np.array(jax_out),
                       atol=1e-7)

    # now let's verify that the gradients are correct
    jax_fn = lambda x, p: net_jax.apply(p, x).flatten().sum()

    jax_params_grad = jit(grad(jax_fn, argnums=(1, )))(jnp.array(x),
                                                       jax_params)[0]

    torch_loss = torch_out.flatten().sum()
    torch_loss.backward()

    torch_grads = torch_to_jax(
        {k: v.grad
         for k, v in net_torch.named_parameters()}, "flexible_ne_rf_model")

    def recursive_compare(d1, d2):
        assert (d1.keys() == d2.keys())
        for key in d1.keys():
            if isinstance(d1[key], dict):
                assert isinstance(d2[key], dict)
                recursive_compare(d1[key], d2[key])
            else:
                assert np.allclose(d1[key], d2[key], rtol=1e-3, atol=1e-7)

    recursive_compare(jax_params_grad, torch_grads)
Code Example #5
File: emnist.py  Project: google/fedjax
def create_dense_model(only_digits: bool = False,
                       hidden_units: int = 200) -> models.Model:
    """Creates EMNIST dense net with haiku."""
    num_classes = 10 if only_digits else 62

    def forward_pass(batch):
        network = hk.Sequential([
            hk.Flatten(),
            hk.Linear(hidden_units),
            jax.nn.relu,
            hk.Linear(hidden_units),
            jax.nn.relu,
            hk.Linear(num_classes),
        ])
        return network(batch['x'])

    transformed_forward_pass = hk.transform(forward_pass)
    return models.create_model_from_haiku(
        transformed_forward_pass=transformed_forward_pass,
        sample_batch=_HAIKU_SAMPLE_BATCH,
        train_loss=_TRAIN_LOSS,
        eval_metrics=_EVAL_METRICS)
Code Example #6
def test_unvectorize_single_output(rngs, x_batch, x_single):
    def f_batch(X):
        return hk.Linear(11)(X)

    init, f_batch = hk.transform(f_batch)
    params = init(next(rngs), x_batch)
    y_batch = f_batch(params, next(rngs), x_batch)
    assert y_batch.shape == (7, 11)

    f_single = unvectorize(f_batch, in_axes=(None, None, 0), out_axes=0)
    y_single = f_single(params, next(rngs), x_single)
    assert y_single.shape == (11, )

    f_single = unvectorize(f_batch, in_axes=(None, None, 0), out_axes=(0, ))
    msg = r"out_axes must be an int for functions with a single output; got: out_axes=\(0,\)"
    with pytest.raises(TypeError, match=msg):
        f_single(params, next(rngs), x_single)

    f_single = unvectorize(f_batch, in_axes=(None, None, 0, 0), out_axes=(0, ))
    msg = r"number of in_axes must match the number of function inputs"
    with pytest.raises(ValueError, match=msg):
        f_single(params, next(rngs), x_single)
Code Example #7
def make_policy_prior_network(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (64, 64)
) -> networks.FeedForwardNetwork:
    """Creates a policy prior network used by the agent."""

    action_size = np.prod(spec.actions.shape, dtype=int)

    def _policy_prior_fn(observation_t,
                         action_tm1,
                         is_training=False,
                         key=None):
        # is_training and key allow defining train/test-dependent modules
        # like dropout.
        del is_training
        del key
        network = hk.nets.MLP(hidden_layer_sizes + (action_size, ))
        # Policy prior returns an action.
        return network(jnp.concatenate([observation_t, action_tm1], axis=-1))

    policy_prior = hk.without_apply_rng(hk.transform(_policy_prior_fn))
    return make_network_from_module(policy_prior, spec)
Code Example #8
  def test_summarize_model(self):

    def model_fun(x):
      """A model with two submodules."""

      class Alpha(hk.Module):  # Alpha submodule.

        def __call__(self, x):
          return hk.Sequential([
              hk.Conv2D(8, (3, 3)), jax.nn.relu,
              hk.MaxPool((1, 2, 2, 1), (1, 2, 2, 1), 'VALID'),
              hk.Flatten(),
              hk.Linear(3, with_bias=False)
          ])(x)

      class Beta(hk.Module):  # Beta submodule.

        def __call__(self, x):
          return hk.Sequential([hk.Flatten(), hk.Linear(3), jax.nn.relu])(x)

      return hk.Linear(1)(Alpha()(x) + Beta()(x))

    model = hk.transform(model_fun)
    x = np.random.randn(1, 12, 15, 1)
    params = model.init(jax.random.PRNGKey(0), x)

    summary = hk_util.summarize_model(params)
    self.assertEqual(
        summary, """
Variable         Shape            #
alpha/conv2_d.b  (8,)             8
alpha/conv2_d.w  (3, 3, 1, 8)    72
alpha/linear.w   (336, 3)      1008
beta/linear.b    (3,)             3
beta/linear.w    (180, 3)       540
linear.b         (1,)             1
linear.w         (3, 1)           3
Total                          1635
""".strip())
Code Example #9
File: second_order_test.py  Project: stjordanis/optax
    def setUp(self):
        super().setUp()

        self.data = np.random.rand(NUM_SAMPLES, NUM_FEATURES)
        self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES)

        def net_fn(z):
            mlp = hk.Sequential(
                [hk.Linear(10), jax.nn.relu,
                 hk.Linear(NUM_CLASSES)],
                name='mlp')
            return jax.nn.log_softmax(mlp(z))

        net = hk.without_apply_rng(hk.transform(net_fn))
        self.parameters = net.init(jax.random.PRNGKey(0), self.data)

        def loss(params, inputs, targets):
            log_probs = net.apply(params, inputs)
            return -jnp.mean(jax.nn.one_hot(targets, NUM_CLASSES) * log_probs)

        self.loss_fn = loss

        def jax_hessian_diag(loss_fun, params, inputs, targets):
            """This is the 'ground-truth' obtained via the JAX library."""
            hess = jax.hessian(loss_fun)(params, inputs, targets)

            # Extracts the diagonal components.
            hess_diag = collections.defaultdict(dict)
            for k0, k1 in itertools.product(params.keys(), ['w', 'b']):
                params_shape = params[k0][k1].shape
                n_params = np.prod(params_shape)
                hess_diag[k0][k1] = jnp.diag(hess[k0][k1][k0][k1].reshape(
                    n_params, n_params)).reshape(params_shape)
            for k, v in hess_diag.items():
                hess_diag[k] = v
            return second_order.ravel(hess_diag)

        self.hessian = jax_hessian_diag(self.loss_fn, self.parameters,
                                        self.data, self.labels)
Code Example #10
    def __init__(
        self,
        forward_fn: PolicyValueFn,
        initial_state_fn: Callable[[], hk.LSTMState],
        rng: hk.PRNGSequence,
        variable_client: Optional[variable_utils.VariableClient] = None,
        adder: Optional[adders.Adder] = None,
    ):

        # Store these for later use.
        self._adder = adder
        self._variable_client = variable_client
        self._forward = forward_fn
        self._rng = rng

        # Make sure not to use a random policy after checkpoint restoration by
        # assigning variables before running the environment loop.
        if self._variable_client is not None:
            self._variable_client.update_and_wait()

        self._initial_state = hk.without_apply_rng(
            hk.transform(initial_state_fn)).apply(None)
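hk.without_apply_rng removes the rng argument from the transformed apply function, which is why apply(None) above takes only a params argument; None works because an initial-state function typically creates no parameters. A minimal sketch of the same pattern with a toy forward function (names and shapes are illustrative):

import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
    return hk.Linear(2)(x)

transformed = hk.without_apply_rng(hk.transform(f))
params = transformed.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
y = transformed.apply(params, jnp.ones((1, 3)))  # no rng argument needed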
Code Example #11
File: networks.py  Project: andnp/rl-control-template
def getNetwork(inputs: Tuple, outputs: int, params: Dict[str, Any], seed: int):
    name = params['type']

    if name == 'TwoLayerRelu':
        hidden = params['hidden']
        layers = [hidden, hidden]

        network = partial(nn, layers, outputs)

    elif name == 'OneLayerRelu':
        hidden = params['hidden']
        layers = [hidden]

        network = partial(nn, layers, outputs)

    elif name == 'MinatarNet':

        def conv(x):
            hidden = hk.Sequential([
                hk.Conv2D(16, 3, 2),
                jax.nn.relu,
                hk.Flatten(),
            ])

            return hidden(x)

        hidden = params['hidden']
        layers = [hidden]
        network = pipe([conv, partial(nn, layers, outputs)])

    else:
        raise NotImplementedError()

    network = hk.without_apply_rng(hk.transform(network))
    net_params = network.init(jax.random.PRNGKey(seed),
                              jnp.zeros((1, ) + tuple(inputs)))

    return network, net_params
Code Example #12
File: actors_test.py  Project: vishalbelsare/acme
    def test_feedforward(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)

        def policy(inputs: jnp.ndarray):
            action_values = hk.Sequential([
                hk.Flatten(),
                hk.Linear(env_spec.actions.num_values),
            ])(inputs)
            action = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return action, (action_values, )
            else:
                return action

        policy = hk.transform(policy)

        rng = hk.PRNGSequence(1)
        dummy_obs = utils.add_batch_dim(utils.zeros_like(
            env_spec.observations))
        params = policy.init(next(rng), dummy_obs)

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        if has_extras:
            actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
                policy.apply)
        else:
            actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
                policy.apply)
        actor = actors.GenericActor(actor_core,
                                    random_key=jax.random.PRNGKey(1),
                                    variable_client=variable_client)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Code Example #13
File: train.py  Project: deepmind/dm-haiku
def main(_):
    FLAGS.alsologtostderr = True  # Always log visibly.
    # Create the dataset.
    train_dataset = dataset.AsciiDataset(FLAGS.dataset_path, FLAGS.batch_size,
                                         FLAGS.sequence_length)
    vocab_size = train_dataset.vocab_size

    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                  FLAGS.num_layers, FLAGS.dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(optax.clip_by_global_norm(FLAGS.grad_clip_value),
                            optax.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
        # Using values from state/metrics too often will block the runahead and can
        # cause these overheads to become more prominent.
        if step % LOG_EVERY == 0:
            steps_per_sec = LOG_EVERY / (time.time() - prev_time)
            prev_time = time.time()
            metrics.update({'steps_per_sec': steps_per_sec})
            logging.info({k: float(v) for k, v in metrics.items()})
Code Example #14
def make_world_model_network(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (64, 64)
) -> networks.FeedForwardNetwork:
    """Creates a world model network used by the agent."""

    observation_size = np.prod(spec.observations.shape, dtype=int)

    def _world_model_fn(observation_t, action_t, is_training=False, key=None):
        # is_training and key allow defining train/test-dependent modules
        # like dropout.
        del is_training
        del key
        network = hk.nets.MLP(hidden_layer_sizes + (observation_size + 1, ))
        # World model returns both an observation and a reward.
        observation_tp1, reward_t = jnp.split(network(
            jnp.concatenate([observation_t, action_t], axis=-1)),
                                              [observation_size],
                                              axis=-1)
        return observation_tp1, reward_t

    world_model = hk.without_apply_rng(hk.transform(_world_model_fn))
    return make_network_from_module(world_model, spec)
Code Example #15
  def test_model_workflow(self):
    meta = FooMetadata(hidden_units=[5, 2])
    model = hk.transform(functools.partial(foo_model, meta=meta))

    # Get some random param values.
    batch = {'x': jnp.array([[0.5, 1.0, -1.5]])}
    params = model.init(jax.random.PRNGKey(0), batch)

    # Associate params with the model to get a TrainedModel.
    trained_model = hk_util.TrainedModel(model, meta=meta, params=params)

    # Save and load the model.
    filename = '/tmp/hk_util_test/model.pkl'
    trained_model.save(filename)

    recovered = hk_util.TrainedModel.load(filename, foo_model, FooMetadata)

    # Check that meta, params, and model forward function are the same.
    self.assertEqual(recovered.meta, meta)
    self._assert_tree_equal(recovered.params, params)
    y = recovered(batch)
    expected_y = model.apply(params, batch)
    np.testing.assert_array_equal(y, expected_y)
Code Example #16
def make_network(
        spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
    """Creates networks used by the agent."""
    num_actions = spec.actions.num_values

    def actor_fn(obs, is_training=True, key=None):
        # is_training and key allow using train/test-dependent modules
        # like dropout.
        del is_training
        del key
        mlp = hk.Sequential([hk.Flatten(), hk.nets.MLP([64, 64, num_actions])])
        return mlp(obs)

    policy = hk.without_apply_rng(hk.transform(actor_fn))

    # Create dummy observations to create network parameters.
    dummy_obs = utils.zeros_like(spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)

    network = networks_lib.FeedForwardNetwork(
        lambda key: policy.init(key, dummy_obs), policy.apply)

    return network
Code Example #17
def default_agent(
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    seed: int = 0,
    num_ensemble: int = 20,
) -> BootstrappedDqn:
    """Initialize a Bootstrapped DQN agent with default parameters."""

    # Define network.
    prior_scale = 3.
    hidden_sizes = [50, 50]

    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        """Simple Q-network with randomized prior function."""
        net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
        x = hk.Flatten()(inputs)
        return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

    optimizer = optix.adam(learning_rate=1e-3)
    return BootstrappedDqn(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=hk.transform(network),
        batch_size=128,
        discount=.99,
        num_ensemble=num_ensemble,
        replay_capacity=10000,
        min_replay_size=128,
        sgd_period=1,
        target_update_period=4,
        optimizer=optimizer,
        mask_prob=0.5,
        noise_scale=0.,
        epsilon_fn=lambda _: 0.,
        seed=seed,
    )
Code Example #18
def test_run_network():
    x = np.random.randn(250, 66).astype(np.float32)

    net_torch = FlexibleNeRFModelTorch()
    net_jax = hk.without_apply_rng(
        hk.transform(jax.jit(lambda x: FlexibleNeRFModel()(x))))

    jax_params = torch_to_jax(dict(net_torch.named_parameters()),
                              "flexible_ne_rf_model")

    jax_out = net_jax.apply(jax_params, jnp.array(x))
    torch_out = net_torch(torch.from_numpy(x))

    pts_np = np.random.random((256, 128, 3)).astype(np.float32)
    ray_batch_np = np.random.random((256, 11)).astype(np.float32)

    pts_torch = torch.from_numpy(pts_np)
    ray_batch_torch = torch.from_numpy(ray_batch_np)

    pts_jax = jnp.array(pts_np)
    ray_batch_jax = jnp.array(ray_batch_np)

    torch_result = run_network_torch(
        net_torch,
        pts_torch,
        ray_batch_torch,
        32,
        lambda p: positional_encoding_torch(p, 6),
        lambda p: positional_encoding_torch(p, 4),
    )
    jax_result = run_network(functools.partial(net_jax.apply, jax_params),
                             pts_jax, ray_batch_jax, 32, 6, 4)

    assert np.allclose(np.array(jax_result),
                       torch_result.detach().numpy(),
                       atol=1e-7)
Code Example #19
File: test_module.py  Project: xidulu/numpyro
def test_random_module_mcmc(backend):
    if backend == "flax":
        import flax

        linear_module = flax.nn.Dense.partial(features=1)
        bias_name = "bias"
        weight_name = "kernel"
        random_module = random_flax_module
    elif backend == "haiku":
        import haiku as hk

        linear_module = hk.transform(lambda x: hk.Linear(1)(x))
        bias_name = "linear.b"
        weight_name = "linear.w"
        random_module = random_haiku_module

    def model(data, labels):
        nn = random_module("nn", linear_module,
                           prior={bias_name: dist.Cauchy(), weight_name: dist.Normal()},
                           input_shape=(dim,))
        logits = nn(data).squeeze(-1)
        numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

    N, dim = 3000, 3
    warmup_steps, num_samples = (1000, 1000)
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1., dim + 1.)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    kernel = NUTS(model=model)
    mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
    mcmc.run(random.PRNGKey(2), data, labels)
    mcmc.print_summary()
    samples = mcmc.get_samples()
    assert set(samples.keys()) == {"nn/{}".format(bias_name), "nn/{}".format(weight_name)}
    assert_allclose(np.mean(samples["nn/{}".format(weight_name)].squeeze(-1), 0), true_coefs, atol=0.22)
Code Example #20
File: networks.py  Project: vishalbelsare/acme
def make_continuous_networks(
        environment_spec: specs.EnvironmentSpec,
        policy_layer_sizes: Sequence[int] = (64, 64),
        value_layer_sizes: Sequence[int] = (64, 64),
) -> PPONetworks:
    """Creates PPONetworks to be used for continuous action environments."""

    # Get total number of action dimensions from action spec.
    num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)

    def forward_fn(inputs):
        policy_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP(policy_layer_sizes, activation=jnp.tanh),
            # Note: we don't respect bounded action specs here and instead
            # rely on CanonicalSpecWrapper to clip actions accordingly.
            networks_lib.MultivariateNormalDiagHead(num_dimensions)
        ])
        value_network = hk.Sequential([
            utils.batch_concat,
            hk.nets.MLP(value_layer_sizes, activation=jnp.tanh),
            hk.Linear(1), lambda x: jnp.squeeze(x, axis=-1)
        ])

        action_distribution = policy_network(inputs)
        value = value_network(inputs)
        return (action_distribution, value)

    # Transform into pure functions.
    forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

    dummy_obs = utils.zeros_like(environment_spec.observations)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    network = networks_lib.FeedForwardNetwork(
        lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)
    # Create PPONetworks to add functionality required by the agent.
    return make_ppo_networks(network)
Code Example #21
File: test_utils.py  Project: asmith26/jax_toolkit
    def test_supported_loss_returns_correctly_no_loss_kwargs(self):
        import haiku as hk

        def net_function(x: jnp.ndarray) -> jnp.ndarray:
            net = hk.Sequential([])
            return net(x)

        net_transform = hk.transform(net_function)
        actual_loss_function_wrapper = get_haiku_loss_function(
            net_transform, loss="mean_squared_error")

        # Check works
        rng = jax.random.PRNGKey(42)
        params = net_transform.init(rng, jnp.array(0))

        self.assertEqual(
            0, actual_loss_function_wrapper(params, jnp.array(0),
                                            jnp.array(0)))
        self.assertEqual(
            0, actual_loss_function_wrapper(params, jnp.array(1),
                                            jnp.array(1)))
        self.assertEqual(
            1, actual_loss_function_wrapper(params, jnp.array(0),
                                            jnp.array(1)))
Code Example #22
    def test_outputs_preserved(self):
        num_outputs = 2
        initial_state, update = pop_art.popart(num_outputs,
                                               step_size=1e-3,
                                               scale_lb=1e-6,
                                               scale_ub=1e6)
        state = initial_state()
        key = jax.random.PRNGKey(428)

        def net(x):
            linear = hk.Linear(num_outputs,
                               b_init=initializers.RandomUniform(),
                               name='head')
            return linear(x)

        init_fn, apply_fn = hk.without_apply_rng(hk.transform(net))
        key, subkey1, subkey2 = jax.random.split(key, 3)
        fixed_data = jax.random.uniform(subkey1, (4, 3))
        params = init_fn(subkey2, fixed_data)
        initial_result = apply_fn(params, fixed_data)
        indices = np.asarray([0, 1, 0, 1, 0, 1, 0, 1])
        # Repeatedly update state and verify that params still preserve outputs.
        for _ in range(30):
            key, subkey1, subkey2 = jax.random.split(key, 3)
            targets = jax.random.uniform(subkey1, (8, ))
            linear_params, state = update(params['head'], state, targets,
                                          indices)
            params = data_structures.to_mutable_dict(params)
            params['head'] = linear_params

            # Apply updated linear transformation and unnormalize outputs.
            transform = apply_fn(params, fixed_data)
            out = jnp.broadcast_to(
                state.scale, transform.shape) * transform + jnp.broadcast_to(
                    state.shift, transform.shape)
            np.testing.assert_allclose(initial_result, out, atol=1e-2)
Code Example #23
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
    """Initialize a DQN agent with default parameters."""
    def network(inputs: jnp.ndarray) -> jnp.ndarray:
        flat_inputs = hk.Flatten()(inputs)
        mlp = hk.nets.MLP([64, 64, action_spec.num_values])
        action_values = mlp(flat_inputs)
        return action_values

    return DQN(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=hk.transform(network),
        optimizer=optix.adam(1e-3),
        batch_size=32,
        discount=0.99,
        replay_capacity=10000,
        min_replay_size=100,
        sgd_period=1,
        target_update_period=4,
        epsilon=0.05,
        rng=hk.PRNGSequence(seed),
    )
Code Example #24
def make_q_network(spec,
                   hidden_layer_sizes=(512, 512, 256),
                   architecture='LayerNorm'):
  """DQN network for Aquadem algo."""

  def _q_fn(obs):
    if architecture == 'MLP':  # AQuaOff architecture
      network_fn = hk.nets.MLP
    elif architecture == 'LayerNorm':  # Original AQuaDem architecture
      network_fn = networks_lib.LayerNormMLP
    else:
      raise ValueError('Architecture not recognized')

    network = network_fn(list(hidden_layer_sizes) + [spec.actions.num_values])
    value = network(obs)
    return value

  critic = hk.without_apply_rng(hk.transform(_q_fn))
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs), critic.apply)
  return critic_network
Code Example #25
  def __init__(
      self,
      can_run_backwards: bool,
      latent_system_dim: int,
      latent_system_net_type: str,
      latent_system_kwargs: Dict[str, Any],
      encoder_aggregation_type: Optional[str],
      decoder_de_aggregation_type: Optional[str],
      encoder_kwargs: Dict[str, Any],
      decoder_kwargs: Dict[str, Any],
      num_inference_steps: int,
      num_target_steps: int,
      name: str,
      latent_spatial_shape: Optional[Tuple[int, int]] = (4, 4),
      has_latent_transform: bool = False,
      latent_transform_kwargs: Optional[Dict[str, Any]] = None,
      rescale_by: Optional[str] = "pixels_and_time",
      data_format: str = "NHWC",
      **unused_kwargs
  ):
    # Arguments checks
    encoder_kwargs = encoder_kwargs or dict()
    decoder_kwargs = decoder_kwargs or dict()

    # Set the decoder de-aggregation type to the "same" type as the encoder if
    # not provided.
    if (decoder_de_aggregation_type is None and
        encoder_aggregation_type is not None):
      if encoder_aggregation_type == "linear_projection":
        decoder_de_aggregation_type = "linear_projection"
      elif encoder_aggregation_type in ("mean", "max"):
        decoder_de_aggregation_type = "tile"
      else:
        raise ValueError(f"Unrecognized encoder_aggregation_type="
                         f"{encoder_aggregation_type}")
    if latent_system_net_type == "conv":
      if encoder_aggregation_type is not None:
        raise ValueError("When the latent system is convolutional, the encoder "
                         "aggregation type should be None.")
      if decoder_de_aggregation_type is not None:
        raise ValueError("When the latent system is convolutional, the decoder "
                         "aggregation type should be None.")
    else:
      if encoder_aggregation_type is None:
        raise ValueError("When the latent system is not convolutional, the "
                         "you must provide an encoder aggregation type.")
      if decoder_de_aggregation_type is None:
        raise ValueError("When the latent system is not convolutional, the "
                         "you must provide an decoder aggregation type.")
    if has_latent_transform and latent_transform_kwargs is None:
      raise ValueError("When using latent transformation you have to provide "
                       "the latent_transform_kwargs argument.")
    if unused_kwargs:
      logging.warning("Unused kwargs: %s", str(unused_kwargs))
    super().__init__(**unused_kwargs)
    self.can_run_backwards = can_run_backwards
    self.latent_system_dim = latent_system_dim
    self.latent_system_kwargs = latent_system_kwargs
    self.latent_system_net_type = latent_system_net_type
    self.latent_spatial_shape = latent_spatial_shape
    self.num_inference_steps = num_inference_steps
    self.num_target_steps = num_target_steps
    self.rescale_by = rescale_by
    self.data_format = data_format
    self.name = name

    # Encoder
    self.encoder_kwargs = encoder_kwargs
    self.encoder = hk.transform(
        lambda *args, **kwargs: networks.SpatialConvEncoder(  # pylint: disable=unnecessary-lambda,g-long-lambda
            latent_dim=latent_system_dim,
            aggregation_type=encoder_aggregation_type,
            data_format=data_format,
            name="Encoder",
            **encoder_kwargs
        )(*args, **kwargs))

    # Decoder
    self.decoder_kwargs = decoder_kwargs
    self.decoder = hk.transform(
        lambda *args, **kwargs: networks.SpatialConvDecoder(  # pylint: disable=unnecessary-lambda,g-long-lambda
            initial_spatial_shape=self.latent_spatial_shape,
            de_aggregation_type=decoder_de_aggregation_type,
            data_format=data_format,
            max_de_aggregation_dims=self.latent_system_dim // 2,
            name="Decoder",
            **decoder_kwargs,
        )(*args, **kwargs))

    self.has_latent_transform = has_latent_transform
    if has_latent_transform:
      self.latent_transform = hk.transform(
          lambda *args, **kwargs: networks.make_flexible_net(  # pylint: disable=unnecessary-lambda,g-long-lambda
              net_type=latent_system_net_type,
              output_dims=latent_system_dim,
              name="LatentTransform",
              **latent_transform_kwargs
          )(*args, **kwargs))
    else:
      self.latent_transform = None

    self._jit_init = None
Code Example #26
def main(argv):
  """Trains Prioritized DQN agent on Atari."""
  del argv
  logging.info('Prioritized DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.double_dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  chex.assert_shape(sample_network_input,
                    (FLAGS.environment_height, FLAGS.environment_width,
                     FLAGS.num_stacked_frames))

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  # Note the t in the replay is not exactly aligned with the agent t.
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations *
             int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability,
      FLAGS.normalize_weights, random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.PrioritizedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.NullCheckpoint()

  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent = eval_agent
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_trackers = parts.make_default_trackers(train_agent)
    train_stats = parts.generate_statistics(train_trackers, train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_trackers = parts.make_default_trackers(eval_agent)
    eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('train_state_value', train_stats['state_value'], '%.3f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
Code Example #27
    def fit_full(self, config: MLPTrainingConfig) -> MLP:
        if config.best_epoch is None:
            raise ValueError("best epoch not specified by MLP Config")
        rng_key = jax.random.PRNGKey(0)

        mlp_function = hk.transform(lambda x, training: (create_mlp(
            self.embedding_train.shape[1],
            config,
        ))(x, training))

        X = sps.vstack([self.profile_train, self.profile_test])
        y = jnp.concatenate([self.embedding_train, self.embedding_test],
                            axis=0)
        mb_size = 128

        rng_key, sub_key = jax.random.split(rng_key)
        params = mlp_function.init(
            sub_key,
            jnp.zeros((1, self.profile_train.shape[1]), dtype=jnp.float32),
            True,
        )
        opt = optax.adam(config.learning_rate)
        opt_state = opt.init(params)

        @partial(jax.jit, static_argnums=(3, ))
        def predict(params: hk.Params, rng: PRNGKey, X: jnp.ndarray,
                    training: bool) -> jnp.ndarray:
            return mlp_function.apply(params, rng, X, training)

        @partial(jax.jit, static_argnums=(4, ))
        def loss_fn(
            params: hk.Params,
            rng: PRNGKey,
            X: jnp.ndarray,
            Y: jnp.ndarray,
            training: bool,
        ) -> jnp.ndarray:
            prediction = predict(params, rng, X, training)
            return ((Y - prediction)**2).mean(axis=1).sum()

        @jax.jit
        def update(
            params: hk.Params,
            rng: PRNGKey,
            opt_state: optax.OptState,
            X: jnp.ndarray,
            Y: jnp.ndarray,
        ) -> Tuple[jnp.ndarray, hk.Params, optax.OptState]:
            loss_value = loss_fn(params, rng, X, Y, True)
            grad = jax.grad(loss_fn)(params, rng, X, Y, True)
            updates, opt_state = opt.update(grad, opt_state)
            new_params = optax.apply_updates(params, updates)
            return loss_value, new_params, opt_state

        mb_size = 128
        for _ in tqdm(range(config.best_epoch)):
            train_loss = 0
            for X_mb, y_mb, _ in self.stream(X, y, mb_size):
                rng_key, sub_key = jax.random.split(rng_key)
                loss_value, params, opt_state = update(params, sub_key,
                                                       opt_state, X_mb, y_mb)
                train_loss += loss_value
            train_loss /= self.profile_train.shape[0]
        return MLP(predict, params)
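The static_argnums usage above marks the Python bool training as a static argument, so jax.jit specializes (re-traces) the function for each value instead of trying to trace through it. A minimal, self-contained sketch of the same idea; the function and values here are illustrative only.

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def scale(x, training: bool):
    # Branch on the static Python bool; each value gets its own compiled version.
    return x * 0.5 if training else x

print(scale(jnp.ones(3), True))
print(scale(jnp.ones(3), False))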
Code Example #28
  def __init__(
      self,
      obs_spec: specs.Array,
      unroll_fn: networks_lib.PolicyValueRNN,
      initial_state_fn: Callable[[], hk.LSTMState],
      iterator: Iterator[reverb.ReplaySample],
      optimizer: optax.GradientTransformation,
      random_key: networks_lib.PRNGKey,
      discount: float = 0.99,
      entropy_cost: float = 0.,
      baseline_cost: float = 1.,
      max_abs_reward: float = np.inf,
      counter: counting.Counter = None,
      logger: loggers.Logger = None,
      devices: Optional[Sequence[jax.xla.Device]] = None,
      prefetch_size: int = 2,
      num_prefetch_threads: Optional[int] = None,
  ):

    self._devices = devices or jax.local_devices()

    # Transform into pure functions.
    unroll_fn = hk.without_apply_rng(hk.transform(unroll_fn))
    initial_state_fn = hk.without_apply_rng(hk.transform(initial_state_fn))

    loss_fn = losses.impala_loss(
        unroll_fn,
        discount=discount,
        max_abs_reward=max_abs_reward,
        baseline_cost=baseline_cost,
        entropy_cost=entropy_cost)

    @jax.jit
    def sgd_step(
        state: TrainingState, sample: reverb.ReplaySample
    ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
      """Computes an SGD step, returning new state and metrics for logging."""

      # Compute gradients.
      grad_fn = jax.value_and_grad(loss_fn)
      loss_value, gradients = grad_fn(state.params, sample)

      # Average gradients over pmap replicas before optimizer update.
      gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

      # Apply updates.
      updates, new_opt_state = optimizer.update(gradients, state.opt_state)
      new_params = optax.apply_updates(state.params, updates)

      metrics = {
          'loss': loss_value,
      }

      new_state = TrainingState(params=new_params, opt_state=new_opt_state)

      return new_state, metrics

    def make_initial_state(key: jnp.ndarray) -> TrainingState:
      """Initialises the training state (parameters and optimiser state)."""
      dummy_obs = utils.zeros_like(obs_spec)
      dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
      initial_state = initial_state_fn.apply(None)
      initial_params = unroll_fn.init(key, dummy_obs, initial_state)
      initial_opt_state = optimizer.init(initial_params)
      return TrainingState(params=initial_params, opt_state=initial_opt_state)

    # Initialise training state (parameters and optimiser state).
    state = make_initial_state(random_key)
    self._state = utils.replicate_in_all_devices(state, self._devices)

    if num_prefetch_threads is None:
      num_prefetch_threads = len(self._devices)
    self._prefetched_iterator = utils.sharded_prefetch(
        iterator,
        buffer_size=prefetch_size,
        devices=devices,
        num_threads=num_prefetch_threads,
    )

    self._sgd_step = jax.pmap(
        sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

    # Set up logging/counting.
    self._counter = counter or counting.Counter()
    self._logger = logger or loggers.make_default_logger('learner')
Code Example #29
    def __init__(self,
                 network: networks.QNetwork,
                 obs_spec: specs.Array,
                 discount: float,
                 importance_sampling_exponent: float,
                 target_update_period: int,
                 iterator: Iterator[reverb.ReplaySample],
                 optimizer: optix.InitUpdate,
                 rng: hk.PRNGSequence,
                 max_abs_reward: float = 1.,
                 huber_loss_parameter: float = 1.,
                 replay_client: reverb.Client = None,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None):
        """Initializes the learner."""

        # Transform network into a pure function.
        network = hk.transform(network)

        def loss(params: hk.Params, target_params: hk.Params,
                 sample: reverb.ReplaySample):
            o_tm1, a_tm1, r_t, d_t, o_t = sample.data
            keys, probs = sample.info[:2]

            # Forward pass.
            q_tm1 = network.apply(params, o_tm1)
            q_t_value = network.apply(target_params, o_t)
            q_t_selector = network.apply(params, o_t)

            # Cast and clip rewards.
            d_t = (d_t * discount).astype(jnp.float32)
            r_t = jnp.clip(r_t, -max_abs_reward,
                           max_abs_reward).astype(jnp.float32)

            # Compute double Q-learning n-step TD-error.
            batch_error = jax.vmap(rlax.double_q_learning)
            td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                   q_t_selector)
            batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

            # Importance weighting.
            importance_weights = (1. / probs).astype(jnp.float32)
            importance_weights **= importance_sampling_exponent
            importance_weights /= jnp.max(importance_weights)

            # Reweight.
            mean_loss = jnp.mean(importance_weights * batch_loss)  # []

            priorities = jnp.abs(td_error).astype(jnp.float64)

            return mean_loss, (keys, priorities)

        def sgd_step(
            state: TrainingState, samples: reverb.ReplaySample
        ) -> Tuple[TrainingState, LearnerOutputs]:
            grad_fn = jax.grad(loss, has_aux=True)
            gradients, (keys, priorities) = grad_fn(state.params,
                                                    state.target_params,
                                                    samples)
            updates, new_opt_state = optimizer.update(gradients,
                                                      state.opt_state)
            new_params = optix.apply_updates(state.params, updates)

            new_state = TrainingState(params=new_params,
                                      target_params=state.target_params,
                                      opt_state=new_opt_state,
                                      step=state.step + 1)

            outputs = LearnerOutputs(keys=keys, priorities=priorities)

            return new_state, outputs

        def update_priorities(outputs: LearnerOutputs):
            for key, priority in zip(outputs.keys, outputs.priorities):
                replay_client.mutate_priorities(
                    table=adders.DEFAULT_PRIORITY_TABLE,
                    updates={key: priority})

        # Internalise agent components (replay buffer, networks, optimizer).
        self._replay_client = replay_client
        self._iterator = utils.prefetch(iterator)

        # Internalise the hyperparameters.
        self._target_update_period = target_update_period

        # Internalise logging/counting objects.
        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        # Initialise parameters and optimiser state.
        initial_params = network.init(
            next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
        initial_target_params = network.init(
            next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
        initial_opt_state = optimizer.init(initial_params)

        self._state = TrainingState(params=initial_params,
                                    target_params=initial_target_params,
                                    opt_state=initial_opt_state,
                                    step=0)

        self._forward = jax.jit(network.apply)
        self._sgd_step = jax.jit(sgd_step)
        self._async_priority_updater = async_utils.AsyncExecutor(
            update_priorities)
Code Example #30
File: run_key_door.py  Project: seblee97/dqn_zoo
def main(argv):
    """Trains DQN agent on Atari."""
    del argv
    logging.info("DQN on Atari on %s.",
                 jax.lib.xla_bridge.get_backend().platform)
    random_state = np.random.RandomState(FLAGS.seed)
    rng_key = jax.random.PRNGKey(
        random_state.randint(-sys.maxsize - 1, sys.maxsize + 1))

    if FLAGS.results_csv_path:
        writer = parts.CsvWriter(FLAGS.results_csv_path)
    else:
        writer = parts.NullWriter()

    def environment_builder():
        """Creates Key-Door environment."""
        env = gym_key_door.GymKeyDoor(
            env_args={
                constants.MAP_ASCII_PATH: FLAGS.map_ascii_path,
                constants.MAP_YAML_PATH: FLAGS.map_yaml_path,
                constants.REPRESENTATION: constants.PIXEL,
                constants.SCALING: FLAGS.env_scaling,
                constants.EPISODE_TIMEOUT: FLAGS.max_frames_per_episode,
                constants.GRAYSCALE: False,
                constants.BATCH_DIMENSION: False,
                constants.TORCH_AXES: False,
            },
            env_shape=FLAGS.env_shape,
        )
        return gym_atari.RandomNoopsEnvironmentWrapper(
            env,
            min_noop_steps=1,
            max_noop_steps=30,
            seed=random_state.randint(1, 2**32),
        )

    env = environment_builder()

    logging.info("Environment: %s", FLAGS.environment_name)
    logging.info("Action spec: %s", env.action_spec())
    logging.info("Observation spec: %s", env.observation_spec())
    num_actions = env.action_spec().num_values
    network_fn = networks.dqn_atari_network(num_actions)
    network = hk.transform(network_fn)

    def preprocessor_builder():
        return processors.atari(
            additional_discount=FLAGS.additional_discount,
            max_abs_reward=FLAGS.max_abs_reward,
            resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
            num_action_repeats=FLAGS.num_action_repeats,
            num_pooled_frames=2,
            zero_discount_on_life_loss=True,
            num_stacked_frames=FLAGS.num_stacked_frames,
            grayscaling=True,
        )

    # Create sample network input from sample preprocessor output.
    sample_processed_timestep = preprocessor_builder()(env.reset())
    sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                            sample_processed_timestep)
    sample_network_input = sample_processed_timestep.observation
    assert sample_network_input.shape == (
        FLAGS.environment_height,
        FLAGS.environment_width,
        FLAGS.num_stacked_frames,
    )

    exploration_epsilon_schedule = parts.LinearSchedule(
        begin_t=int(FLAGS.min_replay_capacity_fraction *
                    FLAGS.replay_capacity * FLAGS.num_action_repeats),
        decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                        FLAGS.num_iterations * FLAGS.num_train_frames),
        begin_value=FLAGS.exploration_epsilon_begin_value,
        end_value=FLAGS.exploration_epsilon_end_value,
    )

    if FLAGS.compress_state:

        def encoder(transition):
            return transition._replace(
                s_tm1=replay_lib.compress_array(transition.s_tm1),
                s_t=replay_lib.compress_array(transition.s_t),
            )

        def decoder(transition):
            return transition._replace(
                s_tm1=replay_lib.uncompress_array(transition.s_tm1),
                s_t=replay_lib.uncompress_array(transition.s_t),
            )

    else:
        encoder = None
        decoder = None

    replay_structure = replay_lib.Transition(
        s_tm1=None,
        a_tm1=None,
        r_t=None,
        discount_t=None,
        s_t=None,
    )

    replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                         replay_structure, random_state,
                                         encoder, decoder)

    optimizer = optax.rmsprop(
        learning_rate=FLAGS.learning_rate,
        decay=0.95,
        eps=FLAGS.optimizer_epsilon,
        centered=True,
    )

    if FLAGS.shaping_function_type == constants.NO_PENALTY:
        shaping_function = shaping.NoPenalty()
    if FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
        shaping_function = shaping.HardCodedPenalty(
            penalty=FLAGS.shaping_multiplicative_factor)

    train_rng_key, eval_rng_key = jax.random.split(rng_key)

    train_agent = agent.Dqn(
        preprocessor=preprocessor_builder(),
        sample_network_input=sample_network_input,
        network=network,
        optimizer=optimizer,
        transition_accumulator=replay_lib.TransitionAccumulator(),
        replay=replay,
        shaping_function=shaping_function,
        batch_size=FLAGS.batch_size,
        exploration_epsilon=exploration_epsilon_schedule,
        min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
        learn_period=FLAGS.learn_period,
        target_network_update_period=FLAGS.target_network_update_period,
        grad_error_bound=FLAGS.grad_error_bound,
        rng_key=train_rng_key,
    )
    eval_agent = parts.EpsilonGreedyActor(
        preprocessor=preprocessor_builder(),
        network=network,
        exploration_epsilon=FLAGS.eval_exploration_epsilon,
        rng_key=eval_rng_key,
    )

    # Set up checkpointing.
    # checkpoint = parts.NullCheckpoint()
    checkpoint = parts.ImplementedCheckpoint(
        checkpoint_path=FLAGS.checkpoint_path)

    if checkpoint.can_be_restored():
        checkpoint.restore()
        train_agent.set_state(state=checkpoint.state.train_agent)
        eval_agent.set_state(state=checkpoint.state.eval_agent)
        writer.set_state(state=checkpoint.state.writer)

    state = checkpoint.state
    state.iteration = 0
    state.train_agent = train_agent.get_state()
    state.eval_agent = eval_agent.get_state()
    state.random_state = random_state
    state.writer = writer.get_state()

    while state.iteration <= FLAGS.num_iterations:
        # New environment for each iteration to allow for determinism if preempted.
        env = environment_builder()

        logging.info("Training iteration %d.", state.iteration)
        train_seq = parts.run_loop(train_agent, env,
                                   FLAGS.max_frames_per_episode)
        num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
        train_seq_truncated = itertools.islice(train_seq, num_train_frames)
        train_stats = parts.generate_statistics(train_seq_truncated)

        logging.info("Evaluation iteration %d.", state.iteration)
        eval_agent.network_params = train_agent.online_params
        eval_seq = parts.run_loop(eval_agent, env,
                                  FLAGS.max_frames_per_episode)
        eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
        eval_stats = parts.generate_statistics(eval_seq_truncated)

        # Logging and checkpointing.
        human_normalized_score = atari_data.get_human_normalized_score(
            FLAGS.environment_name, eval_stats["episode_return"])
        capped_human_normalized_score = np.amin([1.0, human_normalized_score])
        log_output = [
            ("iteration", state.iteration, "%3d"),
            ("frame", state.iteration * FLAGS.num_train_frames, "%5d"),
            ("eval_episode_return", eval_stats["episode_return"], "% 2.2f"),
            ("train_episode_return", train_stats["episode_return"], "% 2.2f"),
            ("eval_num_episodes", eval_stats["num_episodes"], "%3d"),
            ("train_num_episodes", train_stats["num_episodes"], "%3d"),
            ("eval_frame_rate", eval_stats["step_rate"], "%4.0f"),
            ("train_frame_rate", train_stats["step_rate"], "%4.0f"),
            ("train_exploration_epsilon", train_agent.exploration_epsilon,
             "%.3f"),
            ("normalized_return", human_normalized_score, "%.3f"),
            ("capped_normalized_return", capped_human_normalized_score,
             "%.3f"),
            ("human_gap", 1.0 - capped_human_normalized_score, "%.3f"),
        ]
        log_output_str = ", ".join(
            ("%s: " + f) % (n, v) for n, v, f in log_output)
        logging.info(log_output_str)
        writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
        state.iteration += 1
        checkpoint.save()

    writer.close()