def create_logistic_model(
    only_digits: bool = False,
    reg_fn: Optional[Callable[[core.Params], jnp.ndarray]] = None
) -> core.Model:
  """Creates EMNIST logistic model."""
  num_classes = 10 if only_digits else 62

  def forward_pass(batch):
    network = hk.Sequential([
        hk.Flatten(),
        hk.Linear(num_classes),
    ])
    return network(batch['x'])

  transformed_forward_pass = hk.transform(forward_pass)
  return core.create_model_from_haiku(
      transformed_forward_pass=transformed_forward_pass,
      sample_batch=_EMNIST_HAIKU_SAMPLE_BATCH,
      loss_fn=_EMNIST_LOSS_FN,
      reg_fn=reg_fn,
      metrics_fn_map=_EMNIST_METRICS_FN_MAP)
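# The model above is consumed through fedjax's `core.Model` wrapper; the
# underlying Haiku lifecycle it builds on is the standard `init`/`apply`
# pair. A minimal sketch of that lifecycle, assuming a batch dict with an
# 'x' entry and an EMNIST-like image shape (illustrative, not the fedjax
# API):
import haiku as hk
import jax
import jax.numpy as jnp

def _forward(batch):
  return hk.Sequential([hk.Flatten(), hk.Linear(10)])(batch['x'])

_transformed = hk.transform(_forward)
_rng = jax.random.PRNGKey(0)
_sample = {'x': jnp.zeros((1, 28, 28, 1))}  # assumed input shape
_params = _transformed.init(_rng, _sample)            # create parameters
_logits = _transformed.apply(_params, _rng, _sample)  # run the forward pass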
def test_linear_ibp(self):
  def linear_model(inp):
    return hk.Linear(1)(inp)

  z = jnp.array([[1., 2., 3.]])
  params = {
      'linear': {
          'w': jnp.ones((3, 1), dtype=jnp.float32),
          'b': jnp.array([2.])
      }
  }
  fun = functools.partial(
      hk.without_apply_rng(hk.transform(linear_model)).apply, params)
  input_bounds = jax_verify.IntervalBound(z - 1., z + 1.)
  output_bounds = jax_verify.interval_bound_propagation(fun, input_bounds)
  self.assertAlmostEqual(5., output_bounds.lower)
  self.assertAlmostEqual(11., output_bounds.upper)
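# Why [5., 11.] is the expected interval: for y = w^T x + b with all-ones w
# and b = 2, interval arithmetic gives center w^T z + b = 1 + 2 + 3 + 2 = 8
# and radius sum(|w|) * 1 = 3, hence [8 - 3, 8 + 3] = [5, 11]. A standalone
# check of that arithmetic (plain NumPy, independent of jax_verify):
import numpy as np

_w = np.ones((3, 1))
_b = 2.0
_z = np.array([1., 2., 3.])
_center = float(_z @ _w) + _b               # nominal output: 8
_radius = float(np.abs(_w).sum()) * 1.0     # input half-width is 1
assert (_center - _radius, _center + _radius) == (5.0, 11.0)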
def test_bijector_that_assumes_batch_dimensions(self):
  # Create a Haiku conditioner that assumes a single batch dimension.
  def forward(x):
    network = hk.Sequential([hk.Flatten(preserve_dims=1), hk.Linear(3)])
    return network(x)

  init, apply = hk.transform(forward)
  params = init(self.seed, jnp.ones((2, 3)))
  conditioner = functools.partial(apply, params, self.seed)

  bijector = masked_coupling.MaskedCoupling(
      jnp.ones(3) > 0, conditioner, tfb.Scale)
  base = tfd.MultivariateNormalDiag(jnp.zeros((2, 3)), jnp.ones((2, 3)))
  dist = transformed.Transformed(base, bijector)

  # Exercise the trace-based functions.
  assert dist.batch_shape == (2,)
  assert dist.event_shape == (3,)
  assert dist.dtype == jnp.float32
  sample = self.variant(dist.sample)(seed=self.seed)
  assert sample.dtype == dist.dtype
  self.variant(dist.log_prob)(sample)
def test_torch_to_jax():
  x = np.random.randn(250, 66).astype(np.float32)
  net_torch = FlexibleNeRFModelTorch()
  net_jax = hk.without_apply_rng(
      hk.transform(jax.jit(lambda x: FlexibleNeRFModel()(x))))
  jax_params = torch_to_jax(dict(net_torch.named_parameters()),
                            "flexible_ne_rf_model")
  jax_out = net_jax.apply(jax_params, jnp.array(x))
  torch_out = net_torch(torch.from_numpy(x))
  assert np.allclose(torch_out.detach().numpy(), np.array(jax_out), atol=1e-7)

  # Now let's verify that the gradients are correct.
  jax_fn = lambda x, p: net_jax.apply(p, x).flatten().sum()
  jax_params_grad = jit(grad(jax_fn, argnums=(1,)))(jnp.array(x),
                                                    jax_params)[0]
  torch_loss = torch_out.flatten().sum()
  torch_loss.backward()
  torch_grads = torch_to_jax(
      {k: v.grad for k, v in net_torch.named_parameters()},
      "flexible_ne_rf_model")

  def recursive_compare(d1, d2):
    assert d1.keys() == d2.keys()
    for key in d1.keys():
      if isinstance(d1[key], dict):
        assert isinstance(d2[key], dict)
        recursive_compare(d1[key], d2[key])
      else:
        assert np.allclose(d1[key], d2[key], rtol=1e-3, atol=1e-7)

  recursive_compare(jax_params_grad, torch_grads)
def create_dense_model(only_digits: bool = False,
                       hidden_units: int = 200) -> models.Model:
  """Creates EMNIST dense net with haiku."""
  num_classes = 10 if only_digits else 62

  def forward_pass(batch):
    network = hk.Sequential([
        hk.Flatten(),
        hk.Linear(hidden_units),
        jax.nn.relu,
        hk.Linear(hidden_units),
        jax.nn.relu,
        hk.Linear(num_classes),
    ])
    return network(batch['x'])

  transformed_forward_pass = hk.transform(forward_pass)
  return models.create_model_from_haiku(
      transformed_forward_pass=transformed_forward_pass,
      sample_batch=_HAIKU_SAMPLE_BATCH,
      train_loss=_TRAIN_LOSS,
      eval_metrics=_EVAL_METRICS)
def test_unvectorize_single_output(rngs, x_batch, x_single):
  def f_batch(X):
    return hk.Linear(11)(X)

  init, f_batch = hk.transform(f_batch)
  params = init(next(rngs), x_batch)
  y_batch = f_batch(params, next(rngs), x_batch)
  assert y_batch.shape == (7, 11)

  f_single = unvectorize(f_batch, in_axes=(None, None, 0), out_axes=0)
  y_single = f_single(params, next(rngs), x_single)
  assert y_single.shape == (11,)

  f_single = unvectorize(f_batch, in_axes=(None, None, 0), out_axes=(0,))
  msg = r"out_axes must be an int for functions with a single output; got: out_axes=\(0,\)"
  with pytest.raises(TypeError, match=msg):
    f_single(params, next(rngs), x_single)

  f_single = unvectorize(f_batch, in_axes=(None, None, 0, 0), out_axes=(0,))
  msg = r"number of in_axes must match the number of function inputs"
  with pytest.raises(ValueError, match=msg):
    f_single(params, next(rngs), x_single)
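# `unvectorize` (assumed here to come from coax) is morally the inverse of
# jax.vmap: it wraps a batched function so that it accepts one example. A
# hand-rolled sketch of the single-output case, under that assumption (the
# name below is hypothetical, not the library implementation):
import jax.numpy as jnp

def unvectorize_sketch(f_batch, in_axes, out_axes):
  def f_single(*args):
    # Add a length-1 batch dimension to each mapped-over argument ...
    batched = [a if ax is None else jnp.expand_dims(a, ax)
               for a, ax in zip(args, in_axes)]
    # ... and strip it from the output again.
    return jnp.squeeze(f_batch(*batched), axis=out_axes)
  return f_single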
def make_policy_prior_network(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (64, 64)
) -> networks.FeedForwardNetwork:
  """Creates a policy prior network used by the agent."""
  action_size = np.prod(spec.actions.shape, dtype=int)

  def _policy_prior_fn(observation_t, action_tm1, is_training=False,
                       key=None):
    # is_training and key allow defining train/test-dependent modules
    # like dropout.
    del is_training
    del key
    network = hk.nets.MLP(hidden_layer_sizes + (action_size,))
    # Policy prior returns an action.
    return network(jnp.concatenate([observation_t, action_tm1], axis=-1))

  policy_prior = hk.without_apply_rng(hk.transform(_policy_prior_fn))
  return make_network_from_module(policy_prior, spec)
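# `make_network_from_module` is not shown in this snippet; presumably it
# closes over dummy inputs so the returned FeedForwardNetwork can be
# initialised from a key alone. A minimal sketch under that assumption
# (hypothetical helper name; `utils` and `networks` as imported by the
# surrounding acme code):
def make_network_from_module_sketch(module, spec):
  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
  dummy_action = utils.add_batch_dim(utils.zeros_like(spec.actions))
  return networks.FeedForwardNetwork(
      lambda key: module.init(key, dummy_obs, dummy_action), module.apply)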
def test_summarize_model(self):
  def model_fun(x):
    """A model with two submodules."""

    class Alpha(hk.Module):  # Alpha submodule.

      def __call__(self, x):
        return hk.Sequential([
            hk.Conv2D(8, (3, 3)),
            jax.nn.relu,
            hk.MaxPool((1, 2, 2, 1), (1, 2, 2, 1), 'VALID'),
            hk.Flatten(),
            hk.Linear(3, with_bias=False)
        ])(x)

    class Beta(hk.Module):  # Beta submodule.

      def __call__(self, x):
        return hk.Sequential([hk.Flatten(), hk.Linear(3), jax.nn.relu])(x)

    return hk.Linear(1)(Alpha()(x) + Beta()(x))

  model = hk.transform(model_fun)
  x = np.random.randn(1, 12, 15, 1)
  params = model.init(jax.random.PRNGKey(0), x)
  summary = hk_util.summarize_model(params)
  self.assertEqual(summary, """
Variable         Shape            #
alpha/conv2_d.b  (8,)             8
alpha/conv2_d.w  (3, 3, 1, 8)    72
alpha/linear.w   (336, 3)      1008
beta/linear.b    (3,)             3
beta/linear.w    (180, 3)       540
linear.b         (1,)             1
linear.w         (3, 1)           3
Total                          1635
""".strip())
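# A parameter summary like the one asserted above can be derived directly
# from the params pytree, since Haiku params are a mapping of module name to
# {parameter name: array}. A minimal sketch (not hk_util's implementation):
import numpy as np

def summarize_params_sketch(params):
  rows, total = [], 0
  for mod, leaves in sorted(params.items()):
    for name, value in sorted(leaves.items()):
      n = int(np.prod(value.shape))
      total += n
      rows.append(f'{mod}.{name}  {value.shape}  {n}')
  rows.append(f'Total  {total}')
  return '\n'.join(rows)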
def setUp(self):
  super().setUp()
  self.data = np.random.rand(NUM_SAMPLES, NUM_FEATURES)
  self.labels = np.random.randint(NUM_CLASSES, size=NUM_SAMPLES)

  def net_fn(z):
    mlp = hk.Sequential(
        [hk.Linear(10), jax.nn.relu, hk.Linear(NUM_CLASSES)], name='mlp')
    return jax.nn.log_softmax(mlp(z))

  net = hk.without_apply_rng(hk.transform(net_fn))
  self.parameters = net.init(jax.random.PRNGKey(0), self.data)

  def loss(params, inputs, targets):
    log_probs = net.apply(params, inputs)
    return -jnp.mean(hk.one_hot(targets, NUM_CLASSES) * log_probs)

  self.loss_fn = loss

  def jax_hessian_diag(loss_fun, params, inputs, targets):
    """This is the 'ground truth' obtained via the JAX library."""
    hess = jax.hessian(loss_fun)(params, inputs, targets)
    # Extract the diagonal components.
    hess_diag = collections.defaultdict(dict)
    for k0, k1 in itertools.product(params.keys(), ['w', 'b']):
      params_shape = params[k0][k1].shape
      n_params = np.prod(params_shape)
      hess_diag[k0][k1] = jnp.diag(
          hess[k0][k1][k0][k1].reshape(n_params, n_params)).reshape(
              params_shape)
    # Convert the defaultdict to a regular dict before ravelling.
    hess_diag = dict(hess_diag)
    return second_order.ravel(hess_diag)

  self.hessian = jax_hessian_diag(self.loss_fn, self.parameters, self.data,
                                  self.labels)
def __init__(
    self,
    forward_fn: PolicyValueFn,
    initial_state_fn: Callable[[], hk.LSTMState],
    rng: hk.PRNGSequence,
    variable_client: Optional[variable_utils.VariableClient] = None,
    adder: Optional[adders.Adder] = None,
):
  # Store these for later use.
  self._adder = adder
  self._variable_client = variable_client
  self._forward = forward_fn
  self._rng = rng

  # Make sure not to use a random policy after checkpoint restoration by
  # assigning variables before running the environment loop.
  if self._variable_client is not None:
    self._variable_client.update_and_wait()

  self._initial_state = hk.without_apply_rng(
      hk.transform(initial_state_fn)).apply(None)
def getNetwork(inputs: Tuple, outputs: int, params: Dict[str, Any],
               seed: int):
  name = params['type']
  if name == 'TwoLayerRelu':
    hidden = params['hidden']
    layers = [hidden, hidden]
    network = partial(nn, layers, outputs)
  elif name == 'OneLayerRelu':
    hidden = params['hidden']
    layers = [hidden]
    network = partial(nn, layers, outputs)
  elif name == 'MinatarNet':
    def conv(x):
      hidden = hk.Sequential([
          hk.Conv2D(16, 3, 2),
          jax.nn.relu,
          hk.Flatten(),
      ])
      return hidden(x)

    hidden = params['hidden']
    layers = [hidden]
    network = pipe([conv, partial(nn, layers, outputs)])
  else:
    raise NotImplementedError()

  network = hk.without_apply_rng(hk.transform(network))
  net_params = network.init(jax.random.PRNGKey(seed),
                            jnp.zeros((1,) + tuple(inputs)))
  return network, net_params
def test_feedforward(self, has_extras):
  environment = _make_fake_env()
  env_spec = specs.make_environment_spec(environment)

  def policy(inputs: jnp.ndarray):
    action_values = hk.Sequential([
        hk.Flatten(),
        hk.Linear(env_spec.actions.num_values),
    ])(inputs)
    action = jnp.argmax(action_values, axis=-1)
    if has_extras:
      return action, (action_values,)
    else:
      return action

  policy = hk.transform(policy)

  rng = hk.PRNGSequence(1)
  dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
  params = policy.init(next(rng), dummy_obs)

  variable_source = fakes.VariableSource(params)
  variable_client = variable_utils.VariableClient(variable_source, 'policy')

  if has_extras:
    actor_core = actor_core_lib.batched_feed_forward_with_extras_to_actor_core(
        policy.apply)
  else:
    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        policy.apply)
  actor = actors.GenericActor(
      actor_core,
      random_key=jax.random.PRNGKey(1),
      variable_client=variable_client)

  loop = environment_loop.EnvironmentLoop(environment, actor)
  loop.run(20)
def main(_):
  FLAGS.alsologtostderr = True  # Always log visibly.

  # Create the dataset.
  train_dataset = dataset.AsciiDataset(FLAGS.dataset_path, FLAGS.batch_size,
                                       FLAGS.sequence_length)
  vocab_size = train_dataset.vocab_size

  # Set up the model, loss, and updater.
  forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                FLAGS.num_layers, FLAGS.dropout_rate)
  forward_fn = hk.transform(forward_fn)
  loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

  optimizer = optax.chain(
      optax.clip_by_global_norm(FLAGS.grad_clip_value),
      optax.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

  updater = Updater(forward_fn.init, loss_fn, optimizer)
  updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

  # Initialize parameters.
  logging.info('Initializing parameters...')
  rng = jax.random.PRNGKey(428)
  data = next(train_dataset)
  state = updater.init(rng, data)

  logging.info('Starting train loop...')
  prev_time = time.time()
  for step in range(MAX_STEPS):
    data = next(train_dataset)
    state, metrics = updater.update(state, data)
    # We use JAX runahead to mask data preprocessing and JAX dispatch
    # overheads. Using values from state/metrics too often will block the
    # runahead and can cause these overheads to become more prominent.
    if step % LOG_EVERY == 0:
      steps_per_sec = LOG_EVERY / (time.time() - prev_time)
      prev_time = time.time()
      metrics.update({'steps_per_sec': steps_per_sec})
      logging.info({k: float(v) for k, v in metrics.items()})
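# `Updater.update` is defined elsewhere in the script; the optax update it
# presumably wraps follows the standard grad -> update -> apply_updates
# pattern. A minimal sketch, assuming a `loss_fn(params, rng, data)`
# signature as produced by the functools.partial above (hypothetical helper
# name):
import jax
import optax

def sgd_step_sketch(params, opt_state, optimizer, loss_fn, rng, data):
  loss, grads = jax.value_and_grad(loss_fn)(params, rng, data)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state, {'loss': loss}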
def make_world_model_network(
    spec: specs.EnvironmentSpec,
    hidden_layer_sizes: Tuple[int, ...] = (64, 64)
) -> networks.FeedForwardNetwork:
  """Creates a world model network used by the agent."""
  observation_size = np.prod(spec.observations.shape, dtype=int)

  def _world_model_fn(observation_t, action_t, is_training=False, key=None):
    # is_training and key allow defining train/test-dependent modules
    # like dropout.
    del is_training
    del key
    network = hk.nets.MLP(hidden_layer_sizes + (observation_size + 1,))
    # The world model returns both an observation and a reward.
    observation_tp1, reward_t = jnp.split(
        network(jnp.concatenate([observation_t, action_t], axis=-1)),
        [observation_size],
        axis=-1)
    return observation_tp1, reward_t

  world_model = hk.without_apply_rng(hk.transform(_world_model_fn))
  return make_network_from_module(world_model, spec)
def test_model_workflow(self):
  meta = FooMetadata(hidden_units=[5, 2])
  model = hk.transform(functools.partial(foo_model, meta=meta))

  # Get some random param values.
  batch = {'x': jnp.array([[0.5, 1.0, -1.5]])}
  params = model.init(jax.random.PRNGKey(0), batch)

  # Associate params with the model to get a TrainedModel.
  trained_model = hk_util.TrainedModel(model, meta=meta, params=params)

  # Save and load the model.
  filename = '/tmp/hk_util_test/model.pkl'
  trained_model.save(filename)
  recovered = hk_util.TrainedModel.load(filename, foo_model, FooMetadata)

  # Check that meta, params, and model forward function are the same.
  self.assertEqual(recovered.meta, meta)
  self._assert_tree_equal(recovered.params, params)
  y = recovered(batch)
  expected_y = model.apply(params, batch)
  np.testing.assert_array_equal(y, expected_y)
def make_network(
    spec: specs.EnvironmentSpec) -> networks_lib.FeedForwardNetwork:
  """Creates networks used by the agent."""
  num_actions = spec.actions.num_values

  def actor_fn(obs, is_training=True, key=None):
    # is_training and key allow using train/test-dependent modules
    # like dropout.
    del is_training
    del key
    mlp = hk.Sequential([hk.Flatten(), hk.nets.MLP([64, 64, num_actions])])
    return mlp(obs)

  policy = hk.without_apply_rng(hk.transform(actor_fn))

  # Create dummy observations to create network parameters.
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  network = networks_lib.FeedForwardNetwork(
      lambda key: policy.init(key, dummy_obs), policy.apply)
  return network
def default_agent(
    obs_spec: specs.Array,
    action_spec: specs.DiscreteArray,
    seed: int = 0,
    num_ensemble: int = 20,
) -> BootstrappedDqn:
  """Initialize a Bootstrapped DQN agent with default parameters."""

  # Define network.
  prior_scale = 3.
  hidden_sizes = [50, 50]

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    """Simple Q-network with randomized prior function."""
    net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    prior_net = hk.nets.MLP([*hidden_sizes, action_spec.num_values])
    x = hk.Flatten()(inputs)
    return net(x) + prior_scale * lax.stop_gradient(prior_net(x))

  optimizer = optix.adam(learning_rate=1e-3)
  return BootstrappedDqn(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=hk.transform(network),
      batch_size=128,
      discount=.99,
      num_ensemble=num_ensemble,
      replay_capacity=10000,
      min_replay_size=128,
      sgd_period=1,
      target_update_period=4,
      optimizer=optimizer,
      mask_prob=0.5,
      noise_scale=0.,
      epsilon_fn=lambda _: 0.,
      seed=seed,
  )
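# The randomized prior trick above: the prior MLP is never trained, because
# lax.stop_gradient blocks gradients through it, so `prior_net` contributes
# a fixed random offset while only `net` learns. A quick standalone check
# that stop_gradient behaves this way (hypothetical toy function, not the
# agent's network):
import jax
import jax.numpy as jnp
from jax import lax

def _q(params_trained, params_prior):
  return jnp.sum(params_trained * 2. + 3. * lax.stop_gradient(params_prior))

_g = jax.grad(_q, argnums=(0, 1))(jnp.ones(4), jnp.ones(4))
# _g[0] == 2 everywhere; _g[1] == 0 everywhere: the prior stays fixed.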
def test_run_network():
  x = np.random.randn(250, 66).astype(np.float32)
  net_torch = FlexibleNeRFModelTorch()
  net_jax = hk.without_apply_rng(
      hk.transform(jax.jit(lambda x: FlexibleNeRFModel()(x))))
  jax_params = torch_to_jax(dict(net_torch.named_parameters()),
                            "flexible_ne_rf_model")
  jax_out = net_jax.apply(jax_params, jnp.array(x))
  torch_out = net_torch(torch.from_numpy(x))

  pts_np = np.random.random((256, 128, 3)).astype(np.float32)
  ray_batch_np = np.random.random((256, 11)).astype(np.float32)

  pts_torch = torch.from_numpy(pts_np)
  ray_batch_torch = torch.from_numpy(ray_batch_np)
  pts_jax = jnp.array(pts_np)
  ray_batch_jax = jnp.array(ray_batch_np)

  torch_result = run_network_torch(
      net_torch,
      pts_torch,
      ray_batch_torch,
      32,
      lambda p: positional_encoding_torch(p, 6),
      lambda p: positional_encoding_torch(p, 4),
  )
  jax_result = run_network(functools.partial(net_jax.apply, jax_params),
                           pts_jax, ray_batch_jax, 32, 6, 4)
  assert np.allclose(np.array(jax_result), torch_result.detach().numpy(),
                     atol=1e-7)
def test_random_module_mcmc(backend):
  if backend == "flax":
    import flax
    linear_module = flax.nn.Dense.partial(features=1)
    bias_name = "bias"
    weight_name = "kernel"
    random_module = random_flax_module
  elif backend == "haiku":
    import haiku as hk
    linear_module = hk.transform(lambda x: hk.Linear(1)(x))
    bias_name = "linear.b"
    weight_name = "linear.w"
    random_module = random_haiku_module

  def model(data, labels):
    nn = random_module(
        "nn",
        linear_module,
        prior={bias_name: dist.Cauchy(), weight_name: dist.Normal()},
        input_shape=(dim,))
    logits = nn(data).squeeze(-1)
    numpyro.sample("y", dist.Bernoulli(logits=logits), obs=labels)

  N, dim = 3000, 3
  warmup_steps, num_samples = 1000, 1000
  data = random.normal(random.PRNGKey(0), (N, dim))
  true_coefs = np.arange(1., dim + 1.)
  logits = np.sum(true_coefs * data, axis=-1)
  labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

  kernel = NUTS(model=model)
  mcmc = MCMC(kernel, warmup_steps, num_samples, progress_bar=False)
  mcmc.run(random.PRNGKey(2), data, labels)
  mcmc.print_summary()
  samples = mcmc.get_samples()
  assert set(samples.keys()) == {
      "nn/{}".format(bias_name), "nn/{}".format(weight_name)
  }
  assert_allclose(
      np.mean(samples["nn/{}".format(weight_name)].squeeze(-1), 0),
      true_coefs,
      atol=0.22)
def make_continuous_networks(
    environment_spec: specs.EnvironmentSpec,
    policy_layer_sizes: Sequence[int] = (64, 64),
    value_layer_sizes: Sequence[int] = (64, 64),
) -> PPONetworks:
  """Creates PPONetworks to be used for continuous action environments."""

  # Get total number of action dimensions from action spec.
  num_dimensions = np.prod(environment_spec.actions.shape, dtype=int)

  def forward_fn(inputs):
    policy_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP(policy_layer_sizes, activation=jnp.tanh),
        # Note: we don't respect bounded action specs here and instead
        # rely on CanonicalSpecWrapper to clip actions accordingly.
        networks_lib.MultivariateNormalDiagHead(num_dimensions)
    ])
    value_network = hk.Sequential([
        utils.batch_concat,
        hk.nets.MLP(value_layer_sizes, activation=jnp.tanh),
        hk.Linear(1),
        lambda x: jnp.squeeze(x, axis=-1)
    ])
    action_distribution = policy_network(inputs)
    value = value_network(inputs)
    return (action_distribution, value)

  # Transform into pure functions.
  forward_fn = hk.without_apply_rng(hk.transform(forward_fn))

  dummy_obs = utils.zeros_like(environment_spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
  network = networks_lib.FeedForwardNetwork(
      lambda rng: forward_fn.init(rng, dummy_obs), forward_fn.apply)

  # Create PPONetworks to add functionality required by the agent.
  return make_ppo_networks(network)
def test_supported_loss_returns_correctly_no_loss_kwargs(self):
  import haiku as hk

  def net_function(x: jnp.ndarray) -> jnp.ndarray:
    net = hk.Sequential([])
    return net(x)

  net_transform = hk.transform(net_function)
  actual_loss_function_wrapper = get_haiku_loss_function(
      net_transform, loss="mean_squared_error")

  # Check that it works.
  rng = jax.random.PRNGKey(42)
  params = net_transform.init(rng, jnp.array(0))
  self.assertEqual(
      0, actual_loss_function_wrapper(params, jnp.array(0), jnp.array(0)))
  self.assertEqual(
      0, actual_loss_function_wrapper(params, jnp.array(1), jnp.array(1)))
  self.assertEqual(
      1, actual_loss_function_wrapper(params, jnp.array(0), jnp.array(1)))
def test_outputs_preserved(self):
  num_outputs = 2
  initial_state, update = pop_art.popart(
      num_outputs, step_size=1e-3, scale_lb=1e-6, scale_ub=1e6)
  state = initial_state()
  key = jax.random.PRNGKey(428)

  def net(x):
    linear = hk.Linear(
        num_outputs, b_init=initializers.RandomUniform(), name='head')
    return linear(x)

  init_fn, apply_fn = hk.without_apply_rng(hk.transform(net))
  key, subkey1, subkey2 = jax.random.split(key, 3)
  fixed_data = jax.random.uniform(subkey1, (4, 3))
  params = init_fn(subkey2, fixed_data)
  initial_result = apply_fn(params, fixed_data)
  indices = np.asarray([0, 1, 0, 1, 0, 1, 0, 1])

  # Repeatedly update state and verify that params still preserve outputs.
  for _ in range(30):
    key, subkey1, subkey2 = jax.random.split(key, 3)
    targets = jax.random.uniform(subkey1, (8,))
    linear_params, state = update(params['head'], state, targets, indices)
    params = data_structures.to_mutable_dict(params)
    params['head'] = linear_params

    # Apply the updated linear transformation and unnormalize the outputs.
    transform = apply_fn(params, fixed_data)
    out = (jnp.broadcast_to(state.scale, transform.shape) * transform +
           jnp.broadcast_to(state.shift, transform.shape))
    np.testing.assert_allclose(initial_result, out, atol=1e-2)
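# The invariant this test checks: PopArt rescales the linear head so that
# the unnormalized output scale * f(x) + shift is unchanged when the
# statistics move from (s, m) to (s', m'). Preserving outputs requires
# rewriting the head as w' = w * s / s' and b' = (b * s + m - m') / s',
# since then s' * (w'x + b') + m' == s * (wx + b) + m. A numeric spot-check
# of that identity (toy scalars, not the pop_art implementation):
import numpy as np

_w, _b, _s, _m = 2.0, 1.0, 3.0, 0.5      # old head and statistics
_s2, _m2 = 4.0, -0.2                     # new statistics
_w2, _b2 = _w * _s / _s2, (_b * _s + _m - _m2) / _s2
_x = 1.7
assert np.isclose(_s2 * (_w2 * _x + _b2) + _m2, _s * (_w * _x + _b) + _m)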
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Initialize a DQN agent with default parameters."""

  def network(inputs: jnp.ndarray) -> jnp.ndarray:
    flat_inputs = hk.Flatten()(inputs)
    mlp = hk.nets.MLP([64, 64, action_spec.num_values])
    action_values = mlp(flat_inputs)
    return action_values

  return DQN(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=hk.transform(network),
      optimizer=optix.adam(1e-3),
      batch_size=32,
      discount=0.99,
      replay_capacity=10000,
      min_replay_size=100,
      sgd_period=1,
      target_update_period=4,
      epsilon=0.05,
      rng=hk.PRNGSequence(seed),
  )
def make_q_network(spec,
                   hidden_layer_sizes=(512, 512, 256),
                   architecture='LayerNorm'):
  """DQN network for the Aquadem algo."""

  def _q_fn(obs):
    if architecture == 'MLP':  # AQuaOff architecture.
      network_fn = hk.nets.MLP
    elif architecture == 'LayerNorm':  # Original AQuaDem architecture.
      network_fn = networks_lib.LayerNormMLP
    else:
      raise ValueError('Architecture not recognized')

    network = network_fn(list(hidden_layer_sizes) + [spec.actions.num_values])
    value = network(obs)
    return value

  critic = hk.without_apply_rng(hk.transform(_q_fn))
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_obs = utils.add_batch_dim(dummy_obs)
  critic_network = networks_lib.FeedForwardNetwork(
      lambda key: critic.init(key, dummy_obs), critic.apply)
  return critic_network
def __init__(
    self,
    can_run_backwards: bool,
    latent_system_dim: int,
    latent_system_net_type: str,
    latent_system_kwargs: Dict[str, Any],
    encoder_aggregation_type: Optional[str],
    decoder_de_aggregation_type: Optional[str],
    encoder_kwargs: Dict[str, Any],
    decoder_kwargs: Dict[str, Any],
    num_inference_steps: int,
    num_target_steps: int,
    name: str,
    latent_spatial_shape: Optional[Tuple[int, int]] = (4, 4),
    has_latent_transform: bool = False,
    latent_transform_kwargs: Optional[Dict[str, Any]] = None,
    rescale_by: Optional[str] = "pixels_and_time",
    data_format: str = "NHWC",
    **unused_kwargs
):
  # Argument checks.
  encoder_kwargs = encoder_kwargs or dict()
  decoder_kwargs = decoder_kwargs or dict()

  # Set the decoder de-aggregation type to the "same" type as the encoder
  # if not provided.
  if (decoder_de_aggregation_type is None and
      encoder_aggregation_type is not None):
    if encoder_aggregation_type == "linear_projection":
      decoder_de_aggregation_type = "linear_projection"
    elif encoder_aggregation_type in ("mean", "max"):
      decoder_de_aggregation_type = "tile"
    else:
      raise ValueError(f"Unrecognized encoder_aggregation_type="
                       f"{encoder_aggregation_type}")

  if latent_system_net_type == "conv":
    if encoder_aggregation_type is not None:
      raise ValueError("When the latent system is convolutional, the encoder "
                       "aggregation type should be None.")
    if decoder_de_aggregation_type is not None:
      raise ValueError("When the latent system is convolutional, the decoder "
                       "de-aggregation type should be None.")
  else:
    if encoder_aggregation_type is None:
      raise ValueError("When the latent system is not convolutional, you "
                       "must provide an encoder aggregation type.")
    if decoder_de_aggregation_type is None:
      raise ValueError("When the latent system is not convolutional, you "
                       "must provide a decoder de-aggregation type.")
  if has_latent_transform and latent_transform_kwargs is None:
    raise ValueError("When using a latent transformation you have to provide "
                     "the latent_transform_kwargs argument.")

  if unused_kwargs:
    logging.warning("Unused kwargs: %s", str(unused_kwargs))
  super().__init__(**unused_kwargs)

  self.can_run_backwards = can_run_backwards
  self.latent_system_dim = latent_system_dim
  self.latent_system_kwargs = latent_system_kwargs
  self.latent_system_net_type = latent_system_net_type
  self.latent_spatial_shape = latent_spatial_shape
  self.num_inference_steps = num_inference_steps
  self.num_target_steps = num_target_steps
  self.rescale_by = rescale_by
  self.data_format = data_format
  self.name = name

  # Encoder.
  self.encoder_kwargs = encoder_kwargs
  self.encoder = hk.transform(
      lambda *args, **kwargs: networks.SpatialConvEncoder(  # pylint: disable=unnecessary-lambda,g-long-lambda
          latent_dim=latent_system_dim,
          aggregation_type=encoder_aggregation_type,
          data_format=data_format,
          name="Encoder",
          **encoder_kwargs
      )(*args, **kwargs))

  # Decoder.
  self.decoder_kwargs = decoder_kwargs
  self.decoder = hk.transform(
      lambda *args, **kwargs: networks.SpatialConvDecoder(  # pylint: disable=unnecessary-lambda,g-long-lambda
          initial_spatial_shape=self.latent_spatial_shape,
          de_aggregation_type=decoder_de_aggregation_type,
          data_format=data_format,
          max_de_aggregation_dims=self.latent_system_dim // 2,
          name="Decoder",
          **decoder_kwargs,
      )(*args, **kwargs))

  self.has_latent_transform = has_latent_transform
  if has_latent_transform:
    self.latent_transform = hk.transform(
        lambda *args, **kwargs: networks.make_flexible_net(  # pylint: disable=unnecessary-lambda,g-long-lambda
            net_type=latent_system_net_type,
            output_dims=latent_system_dim,
            name="LatentTransform",
            **latent_transform_kwargs
        )(*args, **kwargs))
  else:
    self.latent_transform = None
  self._jit_init = None
def main(argv):
  """Trains Prioritized DQN agent on Atari."""
  del argv
  logging.info('Prioritized DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.double_dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  chex.assert_shape(sample_network_input,
                    (FLAGS.environment_height, FLAGS.environment_width,
                     FLAGS.num_stacked_frames))

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  # Note the t in the replay is not exactly aligned with the agent t.
  importance_sampling_exponent_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity),
      end_t=(FLAGS.num_iterations *
             int(FLAGS.num_train_frames / FLAGS.num_action_repeats)),
      begin_value=FLAGS.importance_sampling_exponent_begin_value,
      end_value=FLAGS.importance_sampling_exponent_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.PrioritizedTransitionReplay(
      FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent,
      importance_sampling_exponent_schedule,
      FLAGS.uniform_sample_probability, FLAGS.normalize_weights,
      random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.PrioritizedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.NullCheckpoint()
  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent = eval_agent
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if
    # preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_trackers = parts.make_default_trackers(train_agent)
    train_stats = parts.generate_statistics(train_trackers,
                                            train_seq_truncated)

    logging.info('Evaluation iteration %d.', state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_trackers = parts.make_default_trackers(eval_agent)
    eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    log_output = [
        ('iteration', state.iteration, '%3d'),
        ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon,
         '%.3f'),
        ('train_state_value', train_stats['state_value'], '%.3f'),
        ('importance_sampling_exponent',
         train_agent.importance_sampling_exponent, '%.3f'),
        ('max_seen_priority', train_agent.max_seen_priority, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
def fit_full(self, config: MLPTrainingConfig) -> MLP:
  if config.best_epoch is None:
    raise ValueError("best epoch not specified by MLP Config")
  rng_key = jax.random.PRNGKey(0)
  mlp_function = hk.transform(
      lambda x, training: create_mlp(
          self.embedding_train.shape[1],
          config,
      )(x, training))
  X = sps.vstack([self.profile_train, self.profile_test])
  y = jnp.concatenate([self.embedding_train, self.embedding_test], axis=0)
  mb_size = 128
  rng_key, sub_key = jax.random.split(rng_key)
  params = mlp_function.init(
      sub_key,
      jnp.zeros((1, self.profile_train.shape[1]), dtype=jnp.float32),
      True,
  )
  opt = optax.adam(config.learning_rate)
  opt_state = opt.init(params)

  @partial(jax.jit, static_argnums=(3,))
  def predict(params: hk.Params, rng: PRNGKey, X: jnp.ndarray,
              training: bool) -> jnp.ndarray:
    return mlp_function.apply(params, rng, X, training)

  @partial(jax.jit, static_argnums=(4,))
  def loss_fn(
      params: hk.Params,
      rng: PRNGKey,
      X: jnp.ndarray,
      Y: jnp.ndarray,
      training: bool,
  ) -> jnp.ndarray:
    prediction = predict(params, rng, X, training)
    return ((Y - prediction)**2).mean(axis=1).sum()

  @jax.jit
  def update(
      params: hk.Params,
      rng: PRNGKey,
      opt_state: optax.OptState,
      X: jnp.ndarray,
      Y: jnp.ndarray,
  ) -> Tuple[jnp.ndarray, hk.Params, optax.OptState]:
    # Compute the loss and its gradient in a single pass.
    loss_value, grad = jax.value_and_grad(loss_fn)(params, rng, X, Y, True)
    updates, opt_state = opt.update(grad, opt_state)
    new_params = optax.apply_updates(params, updates)
    return loss_value, new_params, opt_state

  for _ in tqdm(range(config.best_epoch)):
    train_loss = 0
    for X_mb, y_mb, _ in self.stream(X, y, mb_size):
      rng_key, sub_key = jax.random.split(rng_key)
      loss_value, params, opt_state = update(params, sub_key, opt_state,
                                             X_mb, y_mb)
      train_loss += loss_value
    train_loss /= self.profile_train.shape[0]

  return MLP(predict, params)
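# A note on static_argnums above: the boolean `training` flag changes which
# computation is traced (e.g. dropout on or off), so it must be a static
# argument; jit then compiles one specialization per Python value. A
# standalone toy illustration (hypothetical function, not the code above):
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def _scale(x, double: bool):
  # With `double` static, Python control flow on it is allowed under jit.
  return x * 2 if double else x

_scale(jnp.ones(3), True)   # compiles the doubling variant
_scale(jnp.ones(3), False)  # compiles the identity variant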
def __init__(
    self,
    obs_spec: specs.Array,
    unroll_fn: networks_lib.PolicyValueRNN,
    initial_state_fn: Callable[[], hk.LSTMState],
    iterator: Iterator[reverb.ReplaySample],
    optimizer: optax.GradientTransformation,
    random_key: networks_lib.PRNGKey,
    discount: float = 0.99,
    entropy_cost: float = 0.,
    baseline_cost: float = 1.,
    max_abs_reward: float = np.inf,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    devices: Optional[Sequence[jax.xla.Device]] = None,
    prefetch_size: int = 2,
    num_prefetch_threads: Optional[int] = None,
):
  self._devices = devices or jax.local_devices()

  # Transform into pure functions.
  unroll_fn = hk.without_apply_rng(hk.transform(unroll_fn))
  initial_state_fn = hk.without_apply_rng(hk.transform(initial_state_fn))

  loss_fn = losses.impala_loss(
      unroll_fn,
      discount=discount,
      max_abs_reward=max_abs_reward,
      baseline_cost=baseline_cost,
      entropy_cost=entropy_cost)

  @jax.jit
  def sgd_step(
      state: TrainingState, sample: reverb.ReplaySample
  ) -> Tuple[TrainingState, Dict[str, jnp.ndarray]]:
    """Computes an SGD step, returning new state and metrics for logging."""

    # Compute gradients.
    grad_fn = jax.value_and_grad(loss_fn)
    loss_value, gradients = grad_fn(state.params, sample)

    # Average gradients over pmap replicas before optimizer update.
    gradients = jax.lax.pmean(gradients, _PMAP_AXIS_NAME)

    # Apply updates.
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    metrics = {
        'loss': loss_value,
    }

    new_state = TrainingState(params=new_params, opt_state=new_opt_state)
    return new_state, metrics

  def make_initial_state(key: jnp.ndarray) -> TrainingState:
    """Initialises the training state (parameters and optimiser state)."""
    dummy_obs = utils.zeros_like(obs_spec)
    dummy_obs = utils.add_batch_dim(dummy_obs)  # Dummy 'sequence' dim.
    initial_state = initial_state_fn.apply(None)
    initial_params = unroll_fn.init(key, dummy_obs, initial_state)
    initial_opt_state = optimizer.init(initial_params)
    return TrainingState(params=initial_params, opt_state=initial_opt_state)

  # Initialise training state (parameters and optimiser state).
  state = make_initial_state(random_key)
  self._state = utils.replicate_in_all_devices(state, self._devices)

  if num_prefetch_threads is None:
    num_prefetch_threads = len(self._devices)
  self._prefetched_iterator = utils.sharded_prefetch(
      iterator,
      buffer_size=prefetch_size,
      devices=devices,
      num_threads=num_prefetch_threads,
  )

  self._sgd_step = jax.pmap(
      sgd_step, axis_name=_PMAP_AXIS_NAME, devices=self._devices)

  # Set up logging/counting.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.make_default_logger('learner')
def __init__(self,
             network: networks.QNetwork,
             obs_spec: specs.Array,
             discount: float,
             importance_sampling_exponent: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optix.InitUpdate,
             rng: hk.PRNGSequence,
             max_abs_reward: float = 1.,
             huber_loss_parameter: float = 1.,
             replay_client: Optional[reverb.Client] = None,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initializes the learner."""

  # Transform network into a pure function.
  network = hk.transform(network)

  def loss(params: hk.Params, target_params: hk.Params,
           sample: reverb.ReplaySample):
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data
    keys, probs = sample.info[:2]

    # Forward pass.
    q_tm1 = network.apply(params, o_tm1)
    q_t_value = network.apply(target_params, o_t)
    q_t_selector = network.apply(params, o_t)

    # Cast and clip rewards.
    d_t = (d_t * discount).astype(jnp.float32)
    r_t = jnp.clip(r_t, -max_abs_reward, max_abs_reward).astype(jnp.float32)

    # Compute double Q-learning n-step TD-error.
    batch_error = jax.vmap(rlax.double_q_learning)
    td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
    batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

    # Importance weighting.
    importance_weights = (1. / probs).astype(jnp.float32)
    importance_weights **= importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)

    # Reweight.
    mean_loss = jnp.mean(importance_weights * batch_loss)  # []

    priorities = jnp.abs(td_error).astype(jnp.float64)

    return mean_loss, (keys, priorities)

  def sgd_step(
      state: TrainingState,
      samples: reverb.ReplaySample) -> Tuple[TrainingState, LearnerOutputs]:
    grad_fn = jax.grad(loss, has_aux=True)
    gradients, (keys, priorities) = grad_fn(state.params,
                                            state.target_params, samples)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)

    new_state = TrainingState(
        params=new_params,
        target_params=state.target_params,
        opt_state=new_opt_state,
        step=state.step + 1)

    outputs = LearnerOutputs(keys=keys, priorities=priorities)
    return new_state, outputs

  def update_priorities(outputs: LearnerOutputs):
    for key, priority in zip(outputs.keys, outputs.priorities):
      replay_client.mutate_priorities(
          table=adders.DEFAULT_PRIORITY_TABLE, updates={key: priority})

  # Internalise agent components (replay buffer, networks, optimizer).
  self._replay_client = replay_client
  self._iterator = utils.prefetch(iterator)

  # Internalise the hyperparameters.
  self._target_update_period = target_update_period

  # Internalise logging/counting objects.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Initialise parameters and optimiser state.
  initial_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_target_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_opt_state = optimizer.init(initial_params)

  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=initial_opt_state,
      step=0)

  self._forward = jax.jit(network.apply)
  self._sgd_step = jax.jit(sgd_step)
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)
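# For reference, the per-transition double Q-learning TD error that
# rlax.double_q_learning computes above: the online network selects the next
# action and the target network evaluates it (d_t already folds in the
# discount in the loss above). A hand-rolled sketch, not the rlax source:
import jax.numpy as jnp

def double_q_td_error_sketch(q_tm1, a_tm1, r_t, d_t, q_t_value,
                             q_t_selector):
  a_t = jnp.argmax(q_t_selector)       # action chosen by the online net
  target = r_t + d_t * q_t_value[a_t]  # value evaluated by the target net
  return target - q_tm1[a_tm1]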
def main(argv):
  """Trains DQN agent on Atari."""
  del argv
  logging.info("DQN on Atari on %s.",
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Key-Door environment."""
    env = gym_key_door.GymKeyDoor(
        env_args={
            constants.MAP_ASCII_PATH: FLAGS.map_ascii_path,
            constants.MAP_YAML_PATH: FLAGS.map_yaml_path,
            constants.REPRESENTATION: constants.PIXEL,
            constants.SCALING: FLAGS.env_scaling,
            constants.EPISODE_TIMEOUT: FLAGS.max_frames_per_episode,
            constants.GRAYSCALE: False,
            constants.BATCH_DIMENSION: False,
            constants.TORCH_AXES: False,
        },
        env_shape=FLAGS.env_shape,
    )
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info("Environment: %s", FLAGS.environment_name)
  logging.info("Action spec: %s", env.action_spec())
  logging.info("Observation spec: %s", env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  assert sample_network_input.shape == (
      FLAGS.environment_height,
      FLAGS.environment_width,
      FLAGS.num_stacked_frames,
  )

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value,
  )

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t),
      )

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t),
      )
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                       replay_structure, random_state,
                                       encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  if FLAGS.shaping_function_type == constants.NO_PENALTY:
    shaping_function = shaping.NoPenalty()
  elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
    shaping_function = shaping.HardCodedPenalty(
        penalty=FLAGS.shaping_multiplicative_factor)
  else:
    raise ValueError(
        f"Shaping function type {FLAGS.shaping_function_type} not recognized.")

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.Dqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      shaping_function=shaping_function,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  # checkpoint = parts.NullCheckpoint()
  checkpoint = parts.ImplementedCheckpoint(
      checkpoint_path=FLAGS.checkpoint_path)

  # Populate the default state before a possible restore, so that restored
  # values are not overwritten.
  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent.get_state()
  state.eval_agent = eval_agent.get_state()
  state.random_state = random_state
  state.writer = writer.get_state()

  if checkpoint.can_be_restored():
    checkpoint.restore()
    train_agent.set_state(state=checkpoint.state.train_agent)
    eval_agent.set_state(state=checkpoint.state.eval_agent)
    writer.set_state(state=checkpoint.state.writer)

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if
    # preempted.
    env = environment_builder()

    logging.info("Training iteration %d.", state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_stats = parts.generate_statistics(train_seq_truncated)

    logging.info("Evaluation iteration %d.", state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_stats = parts.generate_statistics(eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats["episode_return"])
    capped_human_normalized_score = np.amin([1.0, human_normalized_score])

    log_output = [
        ("iteration", state.iteration, "%3d"),
        ("frame", state.iteration * FLAGS.num_train_frames, "%5d"),
        ("eval_episode_return", eval_stats["episode_return"], "% 2.2f"),
        ("train_episode_return", train_stats["episode_return"], "% 2.2f"),
        ("eval_num_episodes", eval_stats["num_episodes"], "%3d"),
        ("train_num_episodes", train_stats["num_episodes"], "%3d"),
        ("eval_frame_rate", eval_stats["step_rate"], "%4.0f"),
        ("train_frame_rate", train_stats["step_rate"], "%4.0f"),
        ("train_exploration_epsilon", train_agent.exploration_epsilon,
         "%.3f"),
        ("normalized_return", human_normalized_score, "%.3f"),
        ("capped_normalized_return", capped_human_normalized_score, "%.3f"),
        ("human_gap", 1.0 - capped_human_normalized_score, "%.3f"),
    ]
    log_output_str = ", ".join(("%s: " + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()