def test_feedforward(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
        return hk.Sequential([
            hk.Flatten(),
            hk.Linear(env_spec.actions.num_values),
            lambda x: jnp.argmax(x, axis=-1),
        ])(inputs)

    policy = hk.transform(policy, apply_rng=True)

    rng = hk.PRNGSequence(1)
    dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    params = policy.init(next(rng), dummy_obs)

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.FeedForwardActor(
        policy.apply,
        rng=hk.PRNGSequence(1),
        variable_client=variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
def __init__(
    self,
    environment_spec: specs.EnvironmentSpec,
    network: hk.Transformed,
    config: DQNConfig,
):
    """Initialize the agent."""

    # Data is communicated via reverb replay.
    reverb_replay = replay.make_reverb_prioritized_nstep_replay(
        environment_spec=environment_spec,
        n_step=config.n_step,
        batch_size=config.batch_size,
        max_replay_size=config.max_replay_size,
        min_replay_size=config.min_replay_size,
        priority_exponent=config.priority_exponent,
        discount=config.discount,
    )
    self._server = reverb_replay.server

    optimizer = optax.chain(
        optax.clip_by_global_norm(config.max_gradient_norm),
        optax.adam(config.learning_rate),
    )

    # The learner updates the parameters (and initializes them).
    learner = learning.DQNLearner(
        network=network,
        obs_spec=environment_spec.observations,
        rng=hk.PRNGSequence(config.seed),
        optimizer=optimizer,
        discount=config.discount,
        importance_sampling_exponent=config.importance_sampling_exponent,
        target_update_period=config.target_update_period,
        iterator=reverb_replay.data_iterator,
        replay_client=reverb_replay.client,
    )

    # The actor selects actions according to the policy.
    def policy(params: hk.Params, key: jnp.ndarray,
               observation: jnp.ndarray) -> jnp.ndarray:
        action_values = network.apply(params, observation)
        return rlax.epsilon_greedy(config.epsilon).sample(key, action_values)

    actor = actors.FeedForwardActor(
        policy=policy,
        rng=hk.PRNGSequence(config.seed),
        variable_client=variable_utils.VariableClient(learner, ''),
        adder=reverb_replay.adder)

    super().__init__(
        actor=actor,
        learner=learner,
        min_observations=max(config.batch_size, config.min_replay_size),
        observations_per_step=config.batch_size / config.samples_per_insert,
    )
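# --- Usage sketch (not from the original source) ---
# A hedged sketch of how an agent built around the constructor above might be
# instantiated. The agent class name `DQNAgent`, the `DQNConfig` defaults, and the
# MLP network below are illustrative assumptions, not taken from the source.
import haiku as hk
import jax.numpy as jnp
from acme import specs


def make_network(env_spec: specs.EnvironmentSpec) -> hk.Transformed:
    def q_network(obs: jnp.ndarray) -> jnp.ndarray:
        # Flatten observations and map them to one Q-value per discrete action.
        return hk.Sequential([
            hk.Flatten(),
            hk.nets.MLP([64, 64, env_spec.actions.num_values]),
        ])(obs)

    # The policy above calls `network.apply(params, observation)` without an rng,
    # so the transform must be rng-free on apply.
    return hk.without_apply_rng(hk.transform(q_network))


# env_spec = specs.make_environment_spec(environment)  # environment assumed to exist
# agent = DQNAgent(environment_spec=env_spec,
#                  network=make_network(env_spec),
#                  config=DQNConfig(seed=0))            # hypothetical class names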
def test_recurrent(self, has_extras):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)
    output_size = env_spec.actions.num_values
    obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    rng = hk.PRNGSequence(1)

    @_transform_without_rng
    def network(inputs: jnp.ndarray, state: hk.LSTMState):
        return hk.DeepRNN(
            [hk.Reshape([-1], preserve_dims=1),
             hk.LSTM(output_size)])(inputs, state)

    @_transform_without_rng
    def initial_state(batch_size: Optional[int] = None):
        network = hk.DeepRNN(
            [hk.Reshape([-1], preserve_dims=1),
             hk.LSTM(output_size)])
        return network.initial_state(batch_size)

    initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
    params = network.init(next(rng), obs, initial_state)

    def policy(
        params: jnp.ndarray,
        key: jnp.ndarray,
        observation: jnp.ndarray,
        core_state: hk.LSTMState
    ) -> Tuple[jnp.ndarray, hk.LSTMState]:
        del key  # Unused for test-case deterministic policy.
        action_values, core_state = network.apply(params, observation, core_state)
        actions = jnp.argmax(action_values, axis=-1)
        if has_extras:
            return (actions, (action_values,)), core_state
        else:
            return actions, core_state

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.RecurrentActor(
        policy,
        hk.PRNGSequence(1),
        initial_state,
        variable_client,
        has_extras=has_extras)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
def objective_func(self, params, state, hyperparams, rng,
                   transition_batch, Adv):
    rngs = hk.PRNGSequence(rng)

    # get distribution params from function approximator
    S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
    dist_params, state_new = self.pi.function(params, state, next(rngs), S, True)

    # compute objective: q(s, a_greedy)
    S = self.q_targ.observation_preprocessor(next(rngs), transition_batch.S)
    A = self.pi.proba_dist.mode(dist_params)
    log_pi = self.pi.proba_dist.log_proba(dist_params, A)
    params_q, state_q = hyperparams['q']['params'], hyperparams['q']['function_state']
    Q, _ = self.q_targ.function_type1(params_q, state_q, next(rngs), S, A, True)

    # clip importance weights to reduce variance
    W = jnp.clip(transition_batch.W, 0.1, 10.)

    # the objective
    chex.assert_equal_shape([W, Q])
    chex.assert_rank([W, Q], 1)
    objective = W * Q

    return jnp.mean(objective), (dist_params, log_pi, state_new)
def objective_func(self, params, state, hyperparams, rng,
                   transition_batch, Adv):
    rngs = hk.PRNGSequence(rng)

    # get distribution params from function approximator
    S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
    dist_params, state_new = self.pi.function(params, state, next(rngs), S, True)

    # compute probability ratios
    A = self.pi.proba_dist.preprocess_variate(next(rngs), transition_batch.A)
    log_pi = self.pi.proba_dist.log_proba(dist_params, A)
    ratio = jnp.exp(log_pi - transition_batch.logP)  # π_new / π_old
    ratio_clip = jnp.clip(ratio, 1 - hyperparams['epsilon'], 1 + hyperparams['epsilon'])

    # clip importance weights to reduce variance
    W = jnp.clip(transition_batch.W, 0.1, 10.)

    # ppo-clip objective
    chex.assert_equal_shape([W, Adv, ratio, ratio_clip])
    chex.assert_rank([W, Adv, ratio, ratio_clip], 1)
    objective = W * jnp.minimum(Adv * ratio, Adv * ratio_clip)

    # also pass auxiliary data to avoid multiple forward passes
    return jnp.mean(objective), (dist_params, log_pi, state_new)
def target_func(self, target_params, target_state, rng, transition_batch):
    rngs = hk.PRNGSequence(rng)

    if isinstance(self.q.action_space, Discrete):
        # get greedy action as the argmax over q_targ
        params, state = target_params['q_targ'], target_state['q_targ']
        S_next = self.q_targ.observation_preprocessor(next(rngs), transition_batch.S_next)
        Q_s_next, _ = self.q_targ.function_type2(params, state, next(rngs), S_next, False)
        assert Q_s_next.ndim == 2, f"bad shape: {Q_s_next.shape}"
        A_next = (Q_s_next == Q_s_next.max(axis=1, keepdims=True)).astype(Q_s_next.dtype)
        A_next /= A_next.sum(axis=1, keepdims=True)  # there may be ties

    else:
        # get greedy action as the mode of pi_targ
        params, state = target_params['pi_targ'], target_state['pi_targ']
        S_next = self.pi_targ.observation_preprocessor(next(rngs), transition_batch.S_next)
        A_next = self.pi_targ.mode_func(params, state, next(rngs), S_next)

    # evaluate on q (not q_targ)
    params, state = target_params['q'], target_state['q']
    S_next = self.q.observation_preprocessor(next(rngs), transition_batch.S_next)
    Q_sa_next, _ = self.q.function_type1(params, state, next(rngs), S_next, A_next, False)
    assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}"
    f, f_inv = self.q.value_transform.transform_func, self.q_targ.value_transform.inverse_func
    return f(transition_batch.Rn + transition_batch.In * f_inv(Q_sa_next))
def test_create_toy_example(self):
    data, model = test_util.create_toy_example(
        num_clients=10, num_clusters=2, num_classes=4, num_examples=5, seed=10)
    batch = next(data.create_tf_dataset_for_client(
        data.client_ids[0]).batch(3).as_numpy_iterator())
    params = model.init_params(next(hk.PRNGSequence(0)))
    self.assertTupleEqual(model.apply_fn(params, None, batch).shape, (3, 4))
def loss_func(params, state, hyperparams, rng, transition_batch):
    rngs = hk.PRNGSequence(rng)
    S = self.model.observation_preprocessor(next(rngs), transition_batch.S)
    A = self.model.action_preprocessor(next(rngs), transition_batch.A)

    if is_stochastic(self.model):
        dist_params, new_state = \
            self.model.function_type1(params, state, next(rngs), S, A, True)
        y_pred = self.model.proba_dist.sample(dist_params, next(rngs))
    else:
        y_pred, new_state = self.model.function_type1(
            params, state, next(rngs), S, A, True)

    if is_transition_model(self.model):
        y_true = self.model.observation_preprocessor(next(rngs), transition_batch.S_next)
    elif is_reward_function(self.model):
        y_true = self.model.value_transform.transform_func(transition_batch.Rn)
    else:
        raise AssertionError(f"unexpected model type: {type(self.model)}")

    loss = self.loss_function(y_true, y_pred)
    td_error = -jax.grad(self.loss_function, argnums=1)(y_true, y_pred)

    # add regularization term
    if self.regularizer is not None:
        hparams = hyperparams['regularizer']
        loss = loss + jnp.mean(self.regularizer.function(dist_params, **hparams))

    return loss, (loss, td_error, new_state)
def example_data(cls, env, observation_preprocessor, action_preprocessor,
                 proba_dist, batch_size=1, random_seed=None):

    if not isinstance(env.observation_space, Space):
        raise TypeError(
            "env.observation_space must be derived from gym.Space, "
            f"got: {type(env.observation_space)}")
    if not isinstance(env.action_space, Space):
        raise TypeError(
            "env.action_space must be derived from gym.Space, "
            f"got: {type(env.action_space)}")

    rnd = onp.random.RandomState(random_seed)
    rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

    # these must be provided
    assert observation_preprocessor is not None
    assert action_preprocessor is not None
    assert proba_dist is not None

    # input: state observations
    S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
    S = [observation_preprocessor(next(rngs), s) for s in S]
    S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

    # input: actions
    A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
    A = [action_preprocessor(next(rngs), a) for a in A]
    A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)

    # output: type1
    dist_params_type1 = jax.tree_map(
        lambda x: jnp.asarray(rnd.randn(batch_size, *x.shape[1:])),
        proba_dist.default_priors)
    data_type1 = ExampleData(
        inputs=Inputs(args=ArgsType1(S=S, A=A, is_training=True), static_argnums=(2,)),
        output=dist_params_type1)

    if not isinstance(env.action_space, Discrete):
        return ModelTypes(type1=data_type1, type2=None)

    # output: type2 (if actions are discrete)
    dist_params_type2 = jax.tree_map(
        lambda x: jnp.asarray(rnd.randn(batch_size, env.action_space.n, *x.shape[1:])),
        proba_dist.default_priors)
    data_type2 = ExampleData(
        inputs=Inputs(args=ArgsType2(S=S, is_training=True), static_argnums=(1,)),
        output=dist_params_type2)

    return ModelTypes(type1=data_type1, type2=data_type2)
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
    """Creates an actor-critic agent with default hyperparameters."""

    def network(inputs: jnp.ndarray) -> Tuple[Logits, Value]:
        flat_inputs = hk.Flatten()(inputs)
        torso = hk.nets.MLP([64, 64])
        policy_head = hk.Linear(action_spec.num_values)
        value_head = hk.Linear(1)
        embedding = torso(flat_inputs)
        logits = policy_head(embedding)
        value = value_head(embedding)
        return logits, jnp.squeeze(value, axis=-1)

    return ActorCritic(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=network,
        optimizer=optix.adam(3e-3),
        rng=hk.PRNGSequence(seed),
        sequence_length=32,
        discount=0.99,
        td_lambda=0.9,
    )
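# --- Usage sketch (not from the original source) ---
# A minimal, hedged example of constructing the default agent above; the concrete
# dm_env spec shapes below are illustrative assumptions, not taken from the source.
import numpy as np
from dm_env import specs

example_obs_spec = specs.Array(shape=(4,), dtype=np.float32, name='observation')
example_action_spec = specs.DiscreteArray(num_values=2, name='action')
agent = default_agent(example_obs_spec, example_action_spec, seed=42)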
def postprocess_variate(self, rng, X, index=0, batch_mode=False):
    rngs = hk.PRNGSequence(rng)

    if self._structure_type == StructureType.LEAF:
        return self._structure.postprocess_variate(
            next(rngs), X, index=index, batch_mode=batch_mode)

    if isinstance(self.space, (gym.spaces.MultiDiscrete, gym.spaces.MultiBinary)):
        assert self._structure_type == StructureType.LIST
        return onp.stack([
            dist.postprocess_variate(next(rngs), X[i], index=index, batch_mode=batch_mode)
            for i, dist in enumerate(self._structure)], axis=-1)

    if isinstance(self.space, gym.spaces.Tuple):
        assert self._structure_type == StructureType.LIST
        return tuple(
            dist.postprocess_variate(next(rngs), X[i], index=index, batch_mode=batch_mode)
            for i, dist in enumerate(self._structure))

    if isinstance(self.space, gym.spaces.Dict):
        assert self._structure_type == StructureType.DICT
        return {
            k: dist.postprocess_variate(next(rngs), X[k], index=index, batch_mode=batch_mode)
            for k, dist in self._structure.items()}

    raise AssertionError(
        f"postprocess_variate not implemented for space: {self.space.__class__.__name__}; "
        "please send us a bug report / feature request")
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
    """Creates an actor-critic agent with default hyperparameters."""
    hidden_size = 256
    initial_rnn_state = hk.LSTMState(
        hidden=jnp.zeros((1, hidden_size), dtype=jnp.float32),
        cell=jnp.zeros((1, hidden_size), dtype=jnp.float32))

    def network(inputs: jnp.ndarray,
                state) -> Tuple[Tuple[Logits, Value], LSTMState]:
        flat_inputs = hk.Flatten()(inputs)
        torso = hk.nets.MLP([hidden_size, hidden_size])
        lstm = hk.LSTM(hidden_size)
        policy_head = hk.Linear(action_spec.num_values)
        value_head = hk.Linear(1)
        embedding = torso(flat_inputs)
        embedding, state = lstm(embedding, state)
        logits = policy_head(embedding)
        value = value_head(embedding)
        return (logits, jnp.squeeze(value, axis=-1)), state

    return ActorCriticRNN(
        obs_spec=obs_spec,
        action_spec=action_spec,
        network=network,
        initial_rnn_state=initial_rnn_state,
        optimizer=optix.adam(3e-3),
        rng=hk.PRNGSequence(seed),
        sequence_length=32,
        discount=0.99,
        td_lambda=0.9,
    )
def sample_func(params, state, rng, S):
    rngs = hk.PRNGSequence(rng)
    dist_params, _ = self.function(params, state, next(rngs), S, False)
    X = self.proba_dist.sample(dist_params, next(rngs))
    logP = self.proba_dist.log_proba(dist_params, X)
    return X, logP
def example_data(cls, env, observation_preprocessor=None, batch_size=1, random_seed=None):

    if not isinstance(env.observation_space, Space):
        raise TypeError(
            "env.observation_space must be derived from gym.Space, "
            f"got: {type(env.observation_space)}")

    if observation_preprocessor is None:
        observation_preprocessor = default_preprocessor(env.observation_space)

    rnd = onp.random.RandomState(random_seed)
    rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

    # input: state observations
    S = [safe_sample(env.observation_space, rnd) for _ in range(batch_size)]
    S = [observation_preprocessor(next(rngs), s) for s in S]
    S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

    return ExampleData(
        inputs=Inputs(args=ArgsType2(S=S, is_training=True), static_argnums=(1,)),
        output=jnp.asarray(rnd.randn(batch_size)),
    )
def train_model(ds, attention_fn, position_enc_fn):
    logdir = "./logs/"
    global net, opt

    np.random.seed(cfg['rng_seed'])
    tf.random.set_seed(cfg['rng_seed'])  # For loading / shuffling of dset
    rng_seq = hk.PRNGSequence(cfg['rng_seed'])

    test_image = jnp.asarray(next(ds)[-1], dtype=jnp.float32)[None, :, :, :] / 255.0
    # Use a separate key so the fixed test reconstruction does not consume the training rng.
    reco_key = jax.random.PRNGKey(cfg['rng_seed'] + 1)

    # Initialize network and optimizer
    net = hk.transform(
        partial(forward_fn,
                attention_fn=attention_fn,
                position_enc_fn=position_enc_fn,
                cfg=cfg))
    params = net.init(next(rng_seq), test_image)
    print("Network Initialized")
    print("Model has " + str(hk.data_structures.tree_size(params)) + " parameters")

    opt = get_optimizer(cfg)
    opt_state = opt.init(params)

    # Train
    file_writer = tf.summary.create_file_writer(logdir)
    with file_writer.as_default():
        tf.summary.image("Training Source", test_image, step=0)
    test_image = (test_image - 0.5) * 2

    step = 0
    print("Training Starting")
    while step < 8e4:
        step += 1
        batch = next(ds)
        batch = ((jnp.asarray(batch, dtype=jnp.float32) / 255.) - 0.5) * 2.

        # Do SGD on a batch of training examples.
        loss, params, opt_state = update(params, next(rng_seq), opt_state, batch)

        # Apply model on test sequence for tensorboard
        if step % 500 == 0:
            # Log a reconstruction and accompanying attention masks
            reco, attn = net.apply(params, reco_key, (test_image, True))
            reco = (reco / 2.) + 0.5
            # Horizontally stack masks
            attn = np.expand_dims(np.hstack(list(attn[0].T.reshape(4, 35, 35))),
                                  axis=(0, -1))
            with file_writer.as_default():
                tf.summary.image("Training Reco", reco, step=step)
                tf.summary.image("Attention Masks", attn, step=step)

        if step % 100 == 0:
            with file_writer.as_default():
                tf.summary.scalar('loss', loss, step=step)
def context(
    rng: tp.Union[np.ndarray, int, None] = None,
    building: bool = False,
    get_summaries: bool = False,
    training: bool = True,
) -> tp.Iterator[Context]:
    """"""

    rng_sequence = PRNGSequence(rng) if rng is not None else None

    ctx = Context(
        building=building,
        training=training,
        get_summaries=get_summaries,
        rng_sequence=rng_sequence,
        losses={},
        metrics={},
        summaries=[],
        path_names_c=[],
        level_names_c=[],
        inside_call_c=[],
        module_c=[],
        index_c=[],
    )

    LOCAL.contexts.append(ctx)

    if rng is not None:
        rng = hk.PRNGSequence(rng)

    try:
        yield ctx
    finally:
        LOCAL.contexts.pop()
def objective_func(self, params, state, hyperparams, rng,
                   transition_batch, Adv):
    rngs = hk.PRNGSequence(rng)

    # get distribution params from function approximator
    S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
    dist_params, state_new = self.pi.function(params, state, next(rngs), S, True)

    # compute probability ratios
    A = self.pi.proba_dist.preprocess_variate(next(rngs), transition_batch.A)
    log_pi = self.pi.proba_dist.log_proba(dist_params, A)
    ratio = jnp.exp(log_pi - transition_batch.logP)  # π_new / π_old
    ratio_clip = jnp.clip(ratio, 1 - hyperparams['epsilon'], 1 + hyperparams['epsilon'])

    # ppo-clip objective
    assert Adv.ndim == 1, f"bad shape: {Adv.shape}"
    assert ratio.ndim == 1, f"bad shape: {ratio.shape}"
    assert ratio_clip.ndim == 1, f"bad shape: {ratio_clip.shape}"
    objective = jnp.sum(jnp.minimum(Adv * ratio, Adv * ratio_clip))

    # also pass auxiliary data to avoid multiple forward passes
    return objective, (dist_params, log_pi, state_new)
def quantiles_uniform(rng, batch_size, num_quantiles=32):
    """
    Generate :code:`batch_size` quantile fractions that split the interval :math:`[0, 1]`
    into :code:`num_quantiles` uniformly distributed fractions.

    Parameters
    ----------
    rng : jax.random.PRNGKey
        A pseudo-random number generator key.

    batch_size : int
        The batch size for which the quantile fractions should be generated.

    num_quantiles : int, optional
        The number of quantile fractions. By default 32.

    Returns
    -------
    quantile_fractions : ndarray
        Array of quantile fractions.
    """
    rngs = hk.PRNGSequence(rng)
    quantile_fractions = jax.random.uniform(next(rngs), shape=(batch_size, num_quantiles))
    quantile_fraction_differences = quantile_fractions / \
        jnp.sum(quantile_fractions, axis=-1, keepdims=True)
    quantile_fractions = jnp.cumsum(quantile_fraction_differences, axis=-1)
    return quantile_fractions
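# --- Usage sketch (not from the original source) ---
# A small sketch showing the expected call pattern and output shape of
# quantiles_uniform, assuming jax is importable in this scope.
import jax

example_rng = jax.random.PRNGKey(13)
fractions = quantiles_uniform(example_rng, batch_size=8, num_quantiles=32)
assert fractions.shape == (8, 32)  # each row is an increasing set of fractions in (0, 1]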
def target_func(self, target_params, target_state, rng, transition_batch):
    rngs = hk.PRNGSequence(rng)

    # action propensities
    params, state = target_params['pi_targ'], target_state['pi_targ']
    S_next = self.pi_targ.observation_preprocessor(next(rngs), transition_batch.S_next)
    dist_params, _ = self.pi_targ.function(params, state, next(rngs), S_next, False)
    A_next = jax.nn.softmax(dist_params['logits'], axis=-1)  # only works for Discrete actions

    # evaluate on q_targ
    params, state = target_params['q_targ'], target_state['q_targ']
    S_next = self.q_targ.observation_preprocessor(next(rngs), transition_batch.S_next)
    if is_stochastic(self.q):
        return self._get_target_dist_params(params, state, next(rngs), transition_batch, A_next)
    Q_sa_next, _ = self.q_targ.function_type1(params, state, next(rngs), S_next, A_next, False)
    f, f_inv = self.q.value_transform.transform_func, self.q_targ.value_transform.inverse_func
    return f(transition_batch.Rn + transition_batch.In * f_inv(Q_sa_next))
def __call__(
    self,
    inputs: jnp.ndarray,
    dropout_rate: Optional[float] = None,
    rng=None,
) -> jnp.ndarray:
    """Connects the module to some inputs.

    Args:
      inputs: A Tensor of shape `[batch_size, input_size]`.
      dropout_rate: Optional dropout rate.
      rng: Optional RNG key. Required when using dropout.

    Returns:
      output: The output of the model of size `[batch_size, output_size]`.
    """
    if dropout_rate is not None and rng is None:
        raise ValueError("When using dropout an rng key must be passed.")
    elif dropout_rate is None and rng is not None:
        raise ValueError("RNG should only be passed when using dropout.")

    rng = hk.PRNGSequence(rng) if rng is not None else None
    num_layers = len(self.layers)

    out = inputs
    for i, layer in enumerate(self.layers):
        out = layer(out)
        if i < (num_layers - 1) or self.activate_final:
            # Only perform dropout if we are activating the output.
            if dropout_rate is not None:
                out = hk.dropout(next(rng), dropout_rate, out)
            out = self.activation(out)

    return out
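# --- Usage sketch (not from the original source) ---
# A hedged sketch of driving the __call__ above with dropout inside hk.transform;
# the surrounding MLP-style module and its constructor arguments are assumptions,
# since only __call__ is shown here.
def forward(x):
    net = MLP([128, 128, 10])  # hypothetical constructor for the module above
    # Passing a fresh key is enough: __call__ wraps it in hk.PRNGSequence itself.
    return net(x, dropout_rate=0.1, rng=hk.next_rng_key())

forward_t = hk.transform(forward)
# params = forward_t.init(jax.random.PRNGKey(0), dummy_batch)   # dummy_batch assumed
# out = forward_t.apply(params, jax.random.PRNGKey(1), dummy_batch)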
def test_graph_embedding_model_runs(self):
    graph = jraph.GraphsTuple(
        nodes=np.array([[0, 1, 1],
                        [1, 2, 0],
                        [0, 3, 0],
                        [0, 4, 4]], dtype=np.float32),
        edges=np.array([[1, 1],
                        [2, 2],
                        [3, 3]], dtype=np.float32),
        senders=np.array([0, 1, 2], dtype=np.int32),
        receivers=np.array([1, 2, 3], dtype=np.int32),
        n_node=np.array([4], dtype=np.int32),
        n_edge=np.array([3], dtype=np.int32),
        globals=None)
    embed_dim = 3

    def forward(graph):
        return embedding.GraphEmbeddingModel(embed_dim=3, num_layers=2)(graph)

    init_fn, apply_fn = hk.without_apply_rng(hk.transform(forward))
    key = hk.PRNGSequence(8)
    params = init_fn(next(key), graph)
    out = apply_fn(params, graph)

    self.assertEqual(out.nodes.shape, (graph.nodes.shape[0], embed_dim))
    self.assertEqual(out.edges.shape, (graph.edges.shape[0], embed_dim))
    np.testing.assert_array_equal(out.senders, graph.senders)
    np.testing.assert_array_equal(out.receivers, graph.receivers)
    np.testing.assert_array_equal(out.n_node, graph.n_node)
def grads_and_metrics_func(
        params, target_params, state, target_state, rng, transition_batch):

    rngs = hk.PRNGSequence(rng)
    grads, (loss, td_error, G, Q, state_new) = jax.grad(loss_func, has_aux=True)(
        params, target_params, state, target_state, next(rngs), transition_batch)

    # TD error relative to the target-network estimate
    S = self.q_targ.observation_preprocessor(next(rngs), transition_batch.S)
    A = self.q_targ.action_preprocessor(next(rngs), transition_batch.A)
    Q_targ, _ = self.q_targ.function_type1(
        target_params['q_targ'], target_state['q_targ'], next(rngs), S, A, False)
    td_error_targ = -jax.grad(self.loss_function, argnums=1)(Q, Q_targ)  # e.g. (Q - Q_targ)

    name = self.__class__.__name__
    metrics = {
        f'{name}/loss': loss,
        f'{name}/td_error': jnp.mean(td_error),
        f'{name}/td_error_targ': jnp.mean(td_error_targ),
    }

    # add some diagnostics of the gradients
    metrics.update(get_grads_diagnostics(grads, key_prefix=f'{name}/grads_'))

    return grads, state_new, metrics
def __init__(
    self,
    output_size: int,
    rng: jax.random.PRNGKey,
    with_bias: bool = True,
    w_mu_init: Optional[hk.initializers.Initializer] = None,
    b_mu_init: Optional[hk.initializers.Initializer] = None,
    w_sigma_init: Optional[hk.initializers.Initializer] = None,
    b_sigma_init: Optional[hk.initializers.Initializer] = None,
    name: Optional[str] = None,
    factorized_noise: bool = False
):
    """Constructs the noisy Linear module.

    Args:
      output_size: Output dimensionality.
      rng: PRNG key (or seed) used to drive the noise sampling.
      with_bias: Whether to add a bias to the output.
      w_mu_init: Optional initializer for the weight means. By default, uses random
        values from truncated normal, with stddev `1 / sqrt(fan_in)`. See
        https://arxiv.org/abs/1502.03167v3.
      b_mu_init: Optional initializer for the bias means. By default, zero.
      w_sigma_init: Optional initializer for the weight noise scales.
      b_sigma_init: Optional initializer for the bias noise scales. By default, zero.
      name: Name of the module.
      factorized_noise: Whether to use factorized Gaussian noise.
    """
    super().__init__(name=name)
    self.rng = hk.PRNGSequence(rng)
    self.input_size = None
    self.output_size = output_size
    self.with_bias = with_bias
    self.w_mu_init = w_mu_init
    self.b_mu_init = b_mu_init or jnp.zeros
    self.w_sigma_init = w_sigma_init
    self.b_sigma_init = b_sigma_init or jnp.zeros
    self.factorized = factorized_noise
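# --- Usage sketch (not from the original source) ---
# A hedged sketch of constructing this noisy linear layer inside an hk.transform;
# the class name `NoisyLinear` and its __call__ signature are assumptions, since
# only the constructor is shown above.
def critic(x):
    layer = NoisyLinear(output_size=16,
                        rng=jax.random.PRNGKey(0),
                        factorized_noise=True)  # hypothetical class name
    return layer(x)

critic_t = hk.transform(critic)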
def test_graph_conditioned_transformer_learns(self):
    graphs = jraph.GraphsTuple(
        nodes=np.ones((4, 3), dtype=np.float32),
        edges=np.ones((3, 1), dtype=np.float32),
        senders=np.array([0, 2, 3], dtype=np.int32),
        receivers=np.array([1, 3, 2], dtype=np.int32),
        n_node=np.array([2, 2], dtype=np.int32),
        n_edge=np.array([1, 2], dtype=np.int32),
        globals=None,
    )
    seqs = np.array([[1, 2, 2, 0],
                     [1, 3, 3, 3]], dtype=np.int32)
    vocab_size = seqs.max() + 1
    embed_dim = 8
    max_graph_size = graphs.n_node.max()

    logging.info('Training seqs: %r', seqs)

    x = seqs[:, :-1]
    y = seqs[:, 1:]

    def model_fn(vocab_size, embed_dim):
        return models.Graph2TextTransformer(
            vocab_size=vocab_size,
            emb_dim=embed_dim,
            num_layers=2,
            num_heads=4,
            cutoffs=[],
            gnn_embed_dim=embed_dim,
            gnn_num_layers=2)

    def forward(graphs, inputs, labels, max_graph_size):
        input_mask = (labels != 0).astype(jnp.float32)
        return model_fn(vocab_size, embed_dim).loss(
            graphs, max_graph_size, False, inputs, labels, mask=input_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    rng = hk.PRNGSequence(8)
    params, state = init_fn(next(rng), graphs, x, y, max_graph_size)

    def apply(*args, **kwargs):
        out, state = apply_fn(*args, **kwargs)
        return out[0], (out[1], state)

    apply = jax.jit(apply, static_argnums=6)

    optimizer = optax.chain(
        optax.scale_by_adam(),
        optax.scale(-1e-3))
    opt_state = optimizer.init(params)

    for i in range(500):
        (loss, model_state), grad = jax.value_and_grad(apply, has_aux=True)(
            params, state, next(rng), graphs, x, y, max_graph_size)
        metrics, state = model_state
        updates, opt_state = optimizer.update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
        if (i + 1) % 100 == 0:
            logging.info('Step %d, %r', i + 1,
                         {k: float(v) for k, v in metrics.items()})
    logging.info('Loss: %.8f', loss)
    self.assertLess(loss, 1.0)
def main(_):
    optimizer = optax.adam(FLAGS.learning_rate)

    @jax.jit
    def update(params: hk.Params, prng_key: PRNGKey, opt_state: OptState,
               batch: Batch) -> Tuple[hk.Params, OptState]:
        """Single SGD update step."""
        grads = jax.grad(loss_fn)(params, prng_key, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

    prng_seq = hk.PRNGSequence(42)
    params = log_prob.init(next(prng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)))
    opt_state = optimizer.init(params)

    train_ds = load_dataset(tfds.Split.TRAIN, FLAGS.batch_size)
    valid_ds = load_dataset(tfds.Split.TEST, FLAGS.batch_size)

    for step in range(FLAGS.training_steps):
        params, opt_state = update(params, next(prng_seq), opt_state,
                                   next(train_ds))

        if step % FLAGS.eval_frequency == 0:
            val_loss = eval_fn(params, next(valid_ds))
            logging.info("STEP: %5d; Validation loss: %.3f", step, val_loss)
def test_bow_transformer_runs(self):
    bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1],
                    [0, 1, 0, 0, 1, 0, 1, 0],
                    [1, 0, 0, 0, 1, 0, 0, 1]], dtype=np.int32)
    seqs = np.array([[1, 2, 3, 0, 0],
                     [2, 4, 5, 6, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = seqs.max() + 1

    def forward(bow, inputs, labels):
        model = models.Bow2TextTransformer(
            vocab_size=vocab_size,
            emb_dim=16,
            num_layers=2,
            num_heads=4,
            cutoffs=[])
        return model.loss(bow, inputs, labels)

    init_fn, apply_fn = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), bow, x, y)
    out, _ = apply_fn(params, state, next(key), bow, x, y)
    loss, metrics = out
    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)
def sample_func_type1(params, state, rng, S, A):
    rngs = hk.PRNGSequence(rng)
    dist_params, _ = self.function_type1(params, state, next(rngs), S, A, False)
    S_next = self.proba_dist.sample(dist_params, next(rngs))
    logP = self.proba_dist.log_proba(dist_params, S_next)
    return S_next, logP
def test_transformer_param_count(self):
    seqs = np.array([[1, 2, 3, 0, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = 267_735

    def forward(inputs, labels):
        input_mask = (labels != 0).astype(jnp.float32)
        model = models.TransformerXL(
            vocab_size=vocab_size,
            emb_dim=210,
            num_layers=2,
            num_heads=10,
            dropout_prob=0.0,
            dropout_attn_prob=0.0,
            self_att_init_scale=0.02,
            dense_init_scale=0.02,
            dense_dim=2100,
            cutoffs=(20000, 40000, 200000),  # WikiText-103
            relative_pos_clamp_len=None,
        )
        return model.loss(inputs, labels, mask=input_mask, cache_steps=2)

    init_fn, apply_fn = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), x, y)
    out, _ = apply_fn(params, state, next(key), x, y)
    loss, metrics = out
    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)

    param_count = tree_size(params)
    self.assertEqual(param_count, 58_704_438)
def target_func(self, target_params, target_state, rng, transition_batch):
    rngs = hk.PRNGSequence(rng)

    # compute q-values
    params, state = target_params['q_targ'], target_state['q_targ']
    S_next = self.q_targ.observation_preprocessor(next(rngs), transition_batch.S_next)
    Q_s_next, _ = self.q_targ.function_type2(params, state, next(rngs), S_next, False)

    # action propensities
    params, state = target_params['pi_targ'], target_state['pi_targ']
    S_next = self.pi_targ.observation_preprocessor(next(rngs), transition_batch.S_next)
    dist_params, _ = self.pi_targ.function(params, state, next(rngs), S_next, False)
    P = jax.nn.softmax(dist_params['logits'], axis=-1)

    # project
    assert P.ndim == 2, f"bad shape: {P.shape}"
    assert Q_s_next.ndim == 2, f"bad shape: {Q_s_next.shape}"
    Q_sa_next = jax.vmap(jnp.dot)(P, Q_s_next)

    f, f_inv = self.q.value_transform.transform_func, self.q_targ.value_transform.inverse_func
    return f(transition_batch.Rn + transition_batch.In * f_inv(Q_sa_next))
def test_transformer_with_extra_runs(self):
    extra = np.array([[1, 1, 0, 0],
                      [2, 2, 2, 2],
                      [3, 3, 3, 0]], dtype=np.int32)
    seqs = np.array([[1, 2, 3, 0, 0],
                     [2, 4, 5, 6, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = seqs.max() + 1
    extra_vocab_size = extra.max() + 1

    def forward(inputs, labels, extra):
        input_mask = (labels != 0).astype(jnp.float32)
        extra_mask = (extra != 0).astype(jnp.float32)
        extra = hk.Embed(vocab_size=extra_vocab_size, embed_dim=16)(extra)
        model = models.TransformerXL(
            vocab_size=vocab_size,
            emb_dim=16,
            num_layers=2,
            num_heads=4,
            cutoffs=[],
        )
        return model.loss(inputs, labels, mask=input_mask,
                          extra=extra, extra_mask=extra_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), x, y, extra)
    out, _ = apply_fn(params, state, next(key), x, y, extra)
    loss, metrics = out
    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)