Example #1
  def test_feedforward(self):
    environment = _make_fake_env()
    env_spec = specs.make_environment_spec(environment)

    def policy(inputs: jnp.ndarray):
      return hk.Sequential([
          hk.Flatten(),
          hk.Linear(env_spec.actions.num_values),
          lambda x: jnp.argmax(x, axis=-1),
      ])(inputs)

    policy = hk.transform(policy, apply_rng=True)

    rng = hk.PRNGSequence(1)
    dummy_obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
    params = policy.init(next(rng), dummy_obs)

    variable_source = fakes.VariableSource(params)
    variable_client = variable_utils.VariableClient(variable_source, 'policy')

    actor = actors.FeedForwardActor(
        policy.apply, rng=hk.PRNGSequence(1), variable_client=variable_client)

    loop = environment_loop.EnvironmentLoop(environment, actor)
    loop.run(20)
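A minimal sketch of the same pattern with plain Haiku and JAX (the network and shapes here are illustrative, not from the original test): hk.PRNGSequence turns a single seed into a stream of fresh keys, one per init/apply call.

import haiku as hk
import jax.numpy as jnp

def forward(x: jnp.ndarray) -> jnp.ndarray:
  # Stand-in network; any Haiku-transformable function works here.
  return hk.Linear(4)(x)

forward = hk.transform(forward)
rng = hk.PRNGSequence(1)   # stream of PRNG keys derived from seed 1
dummy = jnp.zeros((1, 8))  # dummy batch used only to create parameters
params = forward.init(next(rng), dummy)
out = forward.apply(params, next(rng), dummy)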
Example #2
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: hk.Transformed,
        config: DQNConfig,
    ):
        """Initialize the agent."""
        # Data is communicated via reverb replay.
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            environment_spec=environment_spec,
            n_step=config.n_step,
            batch_size=config.batch_size,
            max_replay_size=config.max_replay_size,
            min_replay_size=config.min_replay_size,
            priority_exponent=config.priority_exponent,
            discount=config.discount,
        )
        self._server = reverb_replay.server

        optimizer = optax.chain(
            optax.clip_by_global_norm(config.max_gradient_norm),
            optax.adam(config.learning_rate),
        )
        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            rng=hk.PRNGSequence(config.seed),
            optimizer=optimizer,
            discount=config.discount,
            importance_sampling_exponent=config.importance_sampling_exponent,
            target_update_period=config.target_update_period,
            iterator=reverb_replay.data_iterator,
            replay_client=reverb_replay.client,
        )

        # The actor selects actions according to the policy.
        def policy(params: hk.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(config.epsilon).sample(
                key, action_values)

        actor = actors.FeedForwardActor(
            policy=policy,
            rng=hk.PRNGSequence(config.seed),
            variable_client=variable_utils.VariableClient(learner, ''),
            adder=reverb_replay.adder)

        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=max(config.batch_size, config.min_replay_size),
            observations_per_step=config.batch_size /
            config.samples_per_insert,
        )
Example #3
    def test_recurrent(self, has_extras):
        environment = _make_fake_env()
        env_spec = specs.make_environment_spec(environment)
        output_size = env_spec.actions.num_values
        obs = utils.add_batch_dim(utils.zeros_like(env_spec.observations))
        rng = hk.PRNGSequence(1)

        @_transform_without_rng
        def network(inputs: jnp.ndarray, state: hk.LSTMState):
            return hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])(inputs, state)

        @_transform_without_rng
        def initial_state(batch_size: Optional[int] = None):
            network = hk.DeepRNN(
                [hk.Reshape([-1], preserve_dims=1),
                 hk.LSTM(output_size)])
            return network.initial_state(batch_size)

        initial_state = initial_state.apply(initial_state.init(next(rng)), 1)
        params = network.init(next(rng), obs, initial_state)

        def policy(
                params: jnp.ndarray, key: jnp.ndarray,
                observation: jnp.ndarray,
                core_state: hk.LSTMState) -> Tuple[jnp.ndarray, hk.LSTMState]:
            del key  # Unused for test-case deterministic policy.
            action_values, core_state = network.apply(params, observation,
                                                      core_state)
            actions = jnp.argmax(action_values, axis=-1)
            if has_extras:
                return (actions, (action_values, )), core_state
            else:
                return actions, core_state

        variable_source = fakes.VariableSource(params)
        variable_client = variable_utils.VariableClient(
            variable_source, 'policy')

        actor = actors.RecurrentActor(policy,
                                      hk.PRNGSequence(1),
                                      initial_state,
                                      variable_client,
                                      has_extras=has_extras)

        loop = environment_loop.EnvironmentLoop(environment, actor)
        loop.run(20)
Example #4
    def objective_func(self, params, state, hyperparams, rng, transition_batch,
                       Adv):
        rngs = hk.PRNGSequence(rng)

        # get distribution params from function approximator
        S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
        dist_params, state_new = self.pi.function(params, state, next(rngs), S,
                                                  True)

        # compute objective: q(s, a_greedy)
        S = self.q_targ.observation_preprocessor(next(rngs),
                                                 transition_batch.S)
        A = self.pi.proba_dist.mode(dist_params)
        log_pi = self.pi.proba_dist.log_proba(dist_params, A)
        params_q = hyperparams['q']['params']
        state_q = hyperparams['q']['function_state']
        Q, _ = self.q_targ.function_type1(params_q, state_q, next(rngs), S, A,
                                          True)

        # clip importance weights to reduce variance
        W = jnp.clip(transition_batch.W, 0.1, 10.)

        # the objective
        chex.assert_equal_shape([W, Q])
        chex.assert_rank([W, Q], 1)
        objective = W * Q

        return jnp.mean(objective), (dist_params, log_pi, state_new)
Example #5
    def objective_func(self, params, state, hyperparams, rng, transition_batch,
                       Adv):
        rngs = hk.PRNGSequence(rng)

        # get distribution params from function approximator
        S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
        dist_params, state_new = self.pi.function(params, state, next(rngs), S,
                                                  True)

        # compute probability ratios
        A = self.pi.proba_dist.preprocess_variate(next(rngs),
                                                  transition_batch.A)
        log_pi = self.pi.proba_dist.log_proba(dist_params, A)
        ratio = jnp.exp(log_pi - transition_batch.logP)  # π_new / π_old
        ratio_clip = jnp.clip(ratio, 1 - hyperparams['epsilon'],
                              1 + hyperparams['epsilon'])

        # clip importance weights to reduce variance
        W = jnp.clip(transition_batch.W, 0.1, 10.)

        # ppo-clip objective
        chex.assert_equal_shape([W, Adv, ratio, ratio_clip])
        chex.assert_rank([W, Adv, ratio, ratio_clip], 1)
        objective = W * jnp.minimum(Adv * ratio, Adv * ratio_clip)

        # also pass auxiliary data to avoid multiple forward passes
        return jnp.mean(objective), (dist_params, log_pi, state_new)
Example #6
    def target_func(self, target_params, target_state, rng, transition_batch):
        rngs = hk.PRNGSequence(rng)

        if isinstance(self.q.action_space, Discrete):
            # get greedy action as the argmax over q_targ
            params, state = target_params['q_targ'], target_state['q_targ']
            S_next = self.q_targ.observation_preprocessor(
                next(rngs), transition_batch.S_next)
            Q_s_next, _ = self.q_targ.function_type2(params, state, next(rngs),
                                                     S_next, False)
            assert Q_s_next.ndim == 2, f"bad shape: {Q_s_next.shape}"
            A_next = (Q_s_next == Q_s_next.max(axis=1, keepdims=True)).astype(
                Q_s_next.dtype)
            A_next /= A_next.sum(axis=1, keepdims=True)  # there may be ties

        else:
            # get greedy action as the mode of pi_targ
            params, state = target_params['pi_targ'], target_state['pi_targ']
            S_next = self.pi_targ.observation_preprocessor(
                next(rngs), transition_batch.S_next)
            A_next = self.pi_targ.mode_func(params, state, next(rngs), S_next)

        # evaluate on q (not q_targ)
        params, state = target_params['q'], target_state['q']
        S_next = self.q.observation_preprocessor(next(rngs),
                                                 transition_batch.S_next)
        Q_sa_next, _ = self.q.function_type1(params, state, next(rngs), S_next,
                                             A_next, False)

        assert Q_sa_next.ndim == 1, f"bad shape: {Q_sa_next.shape}"
        f, f_inv = self.q.value_transform.transform_func, self.q_targ.value_transform.inverse_func
        return f(transition_batch.Rn + transition_batch.In * f_inv(Q_sa_next))
Example #7
 def test_create_toy_example(self):
   data, model = test_util.create_toy_example(
       num_clients=10, num_clusters=2, num_classes=4, num_examples=5, seed=10)
   batch = next((data.create_tf_dataset_for_client(
       data.client_ids[0]).batch(3).as_numpy_iterator()))
   params = model.init_params(next(hk.PRNGSequence(0)))
   self.assertTupleEqual(model.apply_fn(params, None, batch).shape, (3, 4))
Example #8
        def loss_func(params, state, hyperparams, rng, transition_batch):
            rngs = hk.PRNGSequence(rng)
            S = self.model.observation_preprocessor(next(rngs),
                                                    transition_batch.S)
            A = self.model.action_preprocessor(next(rngs), transition_batch.A)
            if is_stochastic(self.model):
                dist_params, new_state = \
                    self.model.function_type1(params, state, next(rngs), S, A, True)
                y_pred = self.model.proba_dist.sample(dist_params, next(rngs))
            else:
                y_pred, new_state = self.model.function_type1(
                    params, state, next(rngs), S, A, True)

            if is_transition_model(self.model):
                y_true = self.model.observation_preprocessor(
                    next(rngs), transition_batch.S_next)
            elif is_reward_function(self.model):
                y_true = self.model.value_transform.transform_func(
                    transition_batch.Rn)
            else:
                raise AssertionError(
                    f"unexpected model type: {type(self.model)}")

            loss = self.loss_function(y_true, y_pred)
            td_error = -jax.grad(self.loss_function, argnums=1)(y_true, y_pred)

            # add regularization term
            if self.regularizer is not None:
                hparams = hyperparams['regularizer']
                loss = loss + jnp.mean(
                    self.regularizer.function(dist_params, **hparams))

            return loss, (loss, td_error, new_state)
Example #9
    def example_data(cls,
                     env,
                     observation_preprocessor,
                     action_preprocessor,
                     proba_dist,
                     batch_size=1,
                     random_seed=None):

        if not isinstance(env.observation_space, Space):
            raise TypeError(
                "env.observation_space must be derived from gym.Space, "
                f"got: {type(env.observation_space)}")
        if not isinstance(env.action_space, Space):
            raise TypeError(
                f"env.action_space must be derived from gym.Space, got: {type(env.action_space)}"
            )

        rnd = onp.random.RandomState(random_seed)
        rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

        # these must be provided
        assert observation_preprocessor is not None
        assert action_preprocessor is not None
        assert proba_dist is not None

        # input: state observations
        S = [
            safe_sample(env.observation_space, rnd) for _ in range(batch_size)
        ]
        S = [observation_preprocessor(next(rngs), s) for s in S]
        S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

        # input: actions
        A = [safe_sample(env.action_space, rnd) for _ in range(batch_size)]
        A = [action_preprocessor(next(rngs), a) for a in A]
        A = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *A)

        # output: type1
        dist_params_type1 = jax.tree_map(
            lambda x: jnp.asarray(rnd.randn(batch_size, *x.shape[1:])),
            proba_dist.default_priors)
        data_type1 = ExampleData(
            inputs=Inputs(args=ArgsType1(S=S, A=A, is_training=True),
                          static_argnums=(2,)),
            output=dist_params_type1)

        if not isinstance(env.action_space, Discrete):
            return ModelTypes(type1=data_type1, type2=None)

        # output: type2 (if actions are discrete)
        dist_params_type2 = jax.tree_map(
            lambda x: jnp.asarray(
                rnd.randn(batch_size, env.action_space.n, *x.shape[1:])),
            proba_dist.default_priors)
        data_type2 = ExampleData(
            inputs=Inputs(args=ArgsType2(S=S, is_training=True),
                          static_argnums=(1,)),
            output=dist_params_type2)

        return ModelTypes(type1=data_type1, type2=data_type2)
Example #10
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Creates an actor-critic agent with default hyperparameters."""

  def network(inputs: jnp.ndarray) -> Tuple[Logits, Value]:
    flat_inputs = hk.Flatten()(inputs)
    torso = hk.nets.MLP([64, 64])
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)
    embedding = torso(flat_inputs)
    logits = policy_head(embedding)
    value = value_head(embedding)
    return logits, jnp.squeeze(value, axis=-1)

  return ActorCritic(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      optimizer=optix.adam(3e-3),
      rng=hk.PRNGSequence(seed),
      sequence_length=32,
      discount=0.99,
      td_lambda=0.9,
  )
Example #11
    def postprocess_variate(self, rng, X, index=0, batch_mode=False):
        rngs = hk.PRNGSequence(rng)

        if self._structure_type == StructureType.LEAF:
            return self._structure.postprocess_variate(
                next(rngs), X, index=index, batch_mode=batch_mode)

        if isinstance(self.space, (gym.spaces.MultiDiscrete, gym.spaces.MultiBinary)):
            assert self._structure_type == StructureType.LIST
            return onp.stack([
                dist.postprocess_variate(next(rngs), X[i], index=index, batch_mode=batch_mode)
                for i, dist in enumerate(self._structure)], axis=-1)

        if isinstance(self.space, gym.spaces.Tuple):
            assert self._structure_type == StructureType.LIST
            return tuple(
                dist.postprocess_variate(next(rngs), X[i], index=index, batch_mode=batch_mode)
                for i, dist in enumerate(self._structure))

        if isinstance(self.space, gym.spaces.Dict):
            assert self._structure_type == StructureType.DICT
            return {
                k: dist.postprocess_variate(next(rngs), X[k], index=index, batch_mode=batch_mode)
                for k, dist in self._structure.items()}

        raise AssertionError(
            f"postprocess_variate not implemented for space: {self.space.__class__.__name__}; "
            "please send us a bug report / feature request")
Example #12
def default_agent(obs_spec: specs.Array,
                  action_spec: specs.DiscreteArray,
                  seed: int = 0) -> base.Agent:
  """Creates an actor-critic agent with default hyperparameters."""

  hidden_size = 256
  initial_rnn_state = hk.LSTMState(
      hidden=jnp.zeros((1, hidden_size), dtype=jnp.float32),
      cell=jnp.zeros((1, hidden_size), dtype=jnp.float32))

  def network(inputs: jnp.ndarray,
              state) -> Tuple[Tuple[Logits, Value], LSTMState]:
    flat_inputs = hk.Flatten()(inputs)
    torso = hk.nets.MLP([hidden_size, hidden_size])
    lstm = hk.LSTM(hidden_size)
    policy_head = hk.Linear(action_spec.num_values)
    value_head = hk.Linear(1)

    embedding = torso(flat_inputs)
    embedding, state = lstm(embedding, state)
    logits = policy_head(embedding)
    value = value_head(embedding)
    return (logits, jnp.squeeze(value, axis=-1)), state

  return ActorCriticRNN(
      obs_spec=obs_spec,
      action_spec=action_spec,
      network=network,
      initial_rnn_state=initial_rnn_state,
      optimizer=optix.adam(3e-3),
      rng=hk.PRNGSequence(seed),
      sequence_length=32,
      discount=0.99,
      td_lambda=0.9,
  )
Example #13
 def sample_func(params, state, rng, S):
     rngs = hk.PRNGSequence(rng)
     dist_params, _ = self.function(params, state, next(rngs), S,
                                    False)
     X = self.proba_dist.sample(dist_params, next(rngs))
     logP = self.proba_dist.log_proba(dist_params, X)
     return X, logP
Example #14
    def example_data(cls,
                     env,
                     observation_preprocessor=None,
                     batch_size=1,
                     random_seed=None):

        if not isinstance(env.observation_space, Space):
            raise TypeError(
                "env.observation_space must be derived from gym.Space, "
                f"got: {type(env.observation_space)}")

        if observation_preprocessor is None:
            observation_preprocessor = default_preprocessor(
                env.observation_space)

        rnd = onp.random.RandomState(random_seed)
        rngs = hk.PRNGSequence(rnd.randint(jnp.iinfo('int32').max))

        # input: state observations
        S = [
            safe_sample(env.observation_space, rnd) for _ in range(batch_size)
        ]
        S = [observation_preprocessor(next(rngs), s) for s in S]
        S = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *S)

        return ExampleData(
            inputs=Inputs(args=ArgsType2(S=S, is_training=True),
                          static_argnums=(1, )),
            output=jnp.asarray(rnd.randn(batch_size)),
        )
Example #15
def train_model(ds, attention_fn, position_enc_fn):
    logdir = "./logs/"

    global net, opt
    np.random.seed(cfg['rng_seed'])
    tf.random.set_seed(cfg['rng_seed'])  # For loading / shuffling of dset
    rng_seq = hk.PRNGSequence(cfg['rng_seed'])

    test_image = jnp.asarray(next(ds)[-1],
                             dtype=jnp.float32)[None, :, :, :] / 255.0
    reco_key = jax.random.PRNGKey(cfg['rng_seed'] + 1)  # Naughty things will happen if we try to adjust the

    # Initialize network and optimizer
    net = hk.transform(
        partial(forward_fn,
                attention_fn=attention_fn,
                position_enc_fn=position_enc_fn,
                cfg=cfg))
    params = net.init(next(rng_seq), test_image)
    print("Network Initialized")
    print("Model has " + str(hk.data_structures.tree_size(params)) +
          " parameters")

    opt = get_optimizer(cfg)
    opt_state = opt.init(params)

    # Train
    file_writer = tf.summary.create_file_writer(logdir)
    with file_writer.as_default():
        tf.summary.image("Training Source", test_image, step=0)
    test_image = (test_image - 0.5) * 2

    step = 0
    print("Training Starting")
    while step < 8E+4:
        step += 1
        batch = next(ds)
        batch = ((jnp.asarray(batch, dtype=jnp.float32) / 255.) - 0.5) * 2.
        # Do SGD on a batch of training examples.
        loss, params, opt_state = update(params, next(rng_seq), opt_state,
                                         batch)
        # Apply model on test sequence for tensorboard
        if step % 500 == 0:
            # Log a reconstruction and accompanying attention masks
            reco, attn = net.apply(params, reco_key, (test_image, True))
            reco = (reco / 2.) + 0.5
            # Horizontally stack masks
            attn = np.expand_dims(
                np.hstack(list(attn[0].T.reshape(4, 35, 35))), axis=(0, -1))

            with file_writer.as_default():
                tf.summary.image("Training Reco", reco, step=step)
                tf.summary.image("Attention Masks", attn, step=step)

        if step % 100 == 0:
            with file_writer.as_default():
                tf.summary.scalar('loss', loss, step=step)
Example #16
def context(
    rng: tp.Union[np.ndarray, int, None] = None,
    building: bool = False,
    get_summaries: bool = False,
    training: bool = True,
) -> tp.Iterator[Context]:
    """"""

    rng_sequence = PRNGSequence(rng) if rng is not None else None

    ctx = Context(
        building=building,
        training=training,
        get_summaries=get_summaries,
        rng_sequence=rng_sequence,
        losses={},
        metrics={},
        summaries=[],
        path_names_c=[],
        level_names_c=[],
        inside_call_c=[],
        module_c=[],
        index_c=[],
    )

    LOCAL.contexts.append(ctx)

    if rng is not None:
        rng = hk.PRNGSequence(rng)

    try:
        yield ctx
    finally:
        LOCAL.contexts.pop()
Example #17
    def objective_func(self, params, state, hyperparams, rng, transition_batch,
                       Adv):
        rngs = hk.PRNGSequence(rng)

        # get distribution params from function approximator
        S = self.pi.observation_preprocessor(next(rngs), transition_batch.S)
        dist_params, state_new = self.pi.function(params, state, next(rngs), S,
                                                  True)

        # compute probability ratios
        A = self.pi.proba_dist.preprocess_variate(next(rngs),
                                                  transition_batch.A)
        log_pi = self.pi.proba_dist.log_proba(dist_params, A)
        ratio = jnp.exp(log_pi - transition_batch.logP)  # π_new / π_old
        ratio_clip = jnp.clip(ratio, 1 - hyperparams['epsilon'],
                              1 + hyperparams['epsilon'])

        # ppo-clip objective
        assert Adv.ndim == 1, f"bad shape: {Adv.shape}"
        assert ratio.ndim == 1, f"bad shape: {ratio.shape}"
        assert ratio_clip.ndim == 1, f"bad shape: {ratio_clip.shape}"
        objective = jnp.sum(jnp.minimum(Adv * ratio, Adv * ratio_clip))

        # also pass auxiliary data to avoid multiple forward passes
        return objective, (dist_params, log_pi, state_new)
Example #18
def quantiles_uniform(rng, batch_size, num_quantiles=32):
    """
    Generate :code:`batch_size` quantile fractions that split the interval :math:`[0, 1]`
    into :code:`num_quantiles` uniformly distributed fractions.

    Parameters
    ----------
    rng : jax.random.PRNGKey
        A pseudo-random number generator key.
    batch_size : int
        The batch size for which the quantile fractions should be generated.
    num_quantiles : int, optional
        The number of quantile fractions. By default 32.

    Returns
    -------
    quantile_fractions : ndarray
        Array of quantile fractions.
    """
    rngs = hk.PRNGSequence(rng)
    quantile_fractions = jax.random.uniform(next(rngs),
                                            shape=(batch_size, num_quantiles))
    quantile_fraction_differences = quantile_fractions / \
        jnp.sum(quantile_fractions, axis=-1, keepdims=True)
    quantile_fractions = jnp.cumsum(quantile_fraction_differences, axis=-1)
    return quantile_fractions
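A usage sketch, assuming the quantiles_uniform defined above and a standard JAX import; each row of the result is a non-decreasing set of fractions in [0, 1]:

import jax

rng = jax.random.PRNGKey(42)
fractions = quantiles_uniform(rng, batch_size=4, num_quantiles=8)
assert fractions.shape == (4, 8)  # one row of quantile fractions per batch element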
Example #19
    def target_func(self, target_params, target_state, rng, transition_batch):
        rngs = hk.PRNGSequence(rng)

        # action propensities
        params, state = target_params['pi_targ'], target_state['pi_targ']
        S_next = self.pi_targ.observation_preprocessor(next(rngs),
                                                       transition_batch.S_next)
        dist_params, _ = self.pi_targ.function(params, state, next(rngs),
                                               S_next, False)
        A_next = jax.nn.softmax(dist_params['logits'],
                                axis=-1)  # only works for Discrete actions

        # evaluate on q_targ
        params, state = target_params['q_targ'], target_state['q_targ']
        S_next = self.q_targ.observation_preprocessor(next(rngs),
                                                      transition_batch.S_next)

        if is_stochastic(self.q):
            return self._get_target_dist_params(params, state, next(rngs),
                                                transition_batch, A_next)

        Q_sa_next, _ = self.q_targ.function_type1(params, state, next(rngs),
                                                  S_next, A_next, False)
        f, f_inv = self.q.value_transform.transform_func, self.q_targ.value_transform.inverse_func
        return f(transition_batch.Rn + transition_batch.In * f_inv(Q_sa_next))
Example #20
  def __call__(
      self,
      inputs: jnp.ndarray,
      dropout_rate: Optional[float] = None,
      rng=None,
  ) -> jnp.ndarray:
    """Connects the module to some inputs.
    Args:
      inputs: A Tensor of shape `[batch_size, input_size]`.
      dropout_rate: Optional dropout rate.
      rng: Optional RNG key. Require when using dropout.
    Returns:
      output: The output of the model of size `[batch_size, output_size]`.
    """
    if dropout_rate is not None and rng is None:
      raise ValueError("When using dropout an rng key must be passed.")
    elif dropout_rate is None and rng is not None:
      raise ValueError("RNG should only be passed when using dropout.")

    rng = hk.PRNGSequence(rng) if rng is not None else None
    num_layers = len(self.layers)

    out = inputs
    for i, layer in enumerate(self.layers):
      out = layer(out)
      if i < (num_layers - 1) or self.activate_final:
        # Only perform dropout if we are activating the output.
        if dropout_rate is not None:
          out = hk.dropout(next(rng), dropout_rate, out)
        out = self.activation(out)

    return out
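The same key-per-dropout-call pattern can be written with plain Haiku outside of this module; a minimal sketch (the names and sizes here are assumptions, not the original network):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x: jnp.ndarray, dropout_rate: float, key) -> jnp.ndarray:
  rng = hk.PRNGSequence(key)  # one fresh key per dropout call
  out = jax.nn.relu(hk.Linear(16)(x))
  out = hk.dropout(next(rng), dropout_rate, out)
  return hk.Linear(4)(out)

forward = hk.transform(forward)
x = jnp.ones((2, 8))
params = forward.init(jax.random.PRNGKey(0), x, 0.5, jax.random.PRNGKey(1))
y = forward.apply(params, jax.random.PRNGKey(2), x, 0.5, jax.random.PRNGKey(3))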
Example #21
  def test_graph_embedding_model_runs(self):
    graph = jraph.GraphsTuple(
        nodes=np.array([[0, 1, 1],
                        [1, 2, 0],
                        [0, 3, 0],
                        [0, 4, 4]], dtype=np.float32),
        edges=np.array([[1, 1],
                        [2, 2],
                        [3, 3]], dtype=np.float32),
        senders=np.array([0, 1, 2], dtype=np.int32),
        receivers=np.array([1, 2, 3], dtype=np.int32),
        n_node=np.array([4], dtype=np.int32),
        n_edge=np.array([3], dtype=np.int32),
        globals=None)
    embed_dim = 3

    def forward(graph):
      return embedding.GraphEmbeddingModel(embed_dim=3, num_layers=2)(graph)

    init_fn, apply_fn = hk.without_apply_rng(hk.transform(forward))
    key = hk.PRNGSequence(8)
    params = init_fn(next(key), graph)
    out = apply_fn(params, graph)

    self.assertEqual(out.nodes.shape, (graph.nodes.shape[0], embed_dim))
    self.assertEqual(out.edges.shape, (graph.edges.shape[0], embed_dim))
    np.testing.assert_array_equal(out.senders, graph.senders)
    np.testing.assert_array_equal(out.receivers, graph.receivers)
    np.testing.assert_array_equal(out.n_node, graph.n_node)
Example #22
        def grads_and_metrics_func(
                params, target_params, state, target_state, rng, transition_batch):

            rngs = hk.PRNGSequence(rng)
            grads, (loss, td_error, G, Q, state_new) = jax.grad(loss_func, has_aux=True)(
                params, target_params, state, target_state, next(rngs), transition_batch)

            # TD error relative to the target-network estimate
            S = self.q_targ.observation_preprocessor(next(rngs), transition_batch.S)
            A = self.q_targ.action_preprocessor(next(rngs), transition_batch.A)
            Q_targ, _ = self.q_targ.function_type1(
                target_params['q_targ'], target_state['q_targ'], next(rngs), S, A, False)
            td_error_targ = -jax.grad(self.loss_function, argnums=1)(Q, Q_targ)  # e.g. (Q - Q_targ)

            name = self.__class__.__name__
            metrics = {
                f'{name}/loss': loss,
                f'{name}/td_error': jnp.mean(td_error),
                f'{name}/td_error_targ': jnp.mean(td_error_targ),
            }

            # add some diagnostics of the gradients
            metrics.update(get_grads_diagnostics(grads, key_prefix=f'{name}/grads_'))

            return grads, state_new, metrics
Example #23
 def __init__(
           self,
           output_size: int,
           rng: jax.random.PRNGKey,
           with_bias: bool = True,
           w_mu_init: Optional[hk.initializers.Initializer] = None,
           b_mu_init: Optional[hk.initializers.Initializer] = None,
           w_sigma_init: Optional[hk.initializers.Initializer] = None,
           b_sigma_init: Optional[hk.initializers.Initializer] = None,
           name: Optional[str] = None,
           factorized_noise: bool = False
 ):
   """Constructs the Linear module.
   Args:
     output_size: Output dimensionality.
     with_bias: Whether to add a bias to the output.
     w_init: Optional initializer for weights. By default, uses random values
       from truncated normal, with stddev `1 / sqrt(fan_in)`. See
       https://arxiv.org/abs/1502.03167v3.
     b_init: Optional initializer for bias. By default, zero.
     name: Name of the module.
   """
   super().__init__(name=name)
   self.rng = hk.PRNGSequence(rng)
   self.input_size = None
   self.output_size = output_size
   self.with_bias = with_bias
   self.w_mu_init = w_mu_init
   self.b_mu_init = b_mu_init or jnp.zeros
   self.w_sigma_init = w_sigma_init
   self.b_sigma_init = b_sigma_init or jnp.zeros
   self.factorized = factorized_noise
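The constructor above only stores the PRNGSequence; how it is later consumed is not shown. A hypothetical illustration of the pattern (names and shapes are assumptions, not part of the original module): each next(...) call yields a fresh key, so successive noise samples are independent.

import haiku as hk
import jax

def sample_noise(rng: hk.PRNGSequence, shape):
  # e.g. called once per forward pass of a noisy layer
  return jax.random.normal(next(rng), shape)

noise = sample_noise(hk.PRNGSequence(0), (3, 5))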
Example #24
  def test_graph_conditioned_transformer_learns(self):
    graphs = jraph.GraphsTuple(
        nodes=np.ones((4, 3), dtype=np.float32),
        edges=np.ones((3, 1), dtype=np.float32),
        senders=np.array([0, 2, 3], dtype=np.int32),
        receivers=np.array([1, 3, 2], dtype=np.int32),
        n_node=np.array([2, 2], dtype=np.int32),
        n_edge=np.array([1, 2], dtype=np.int32),
        globals=None,
        )
    seqs = np.array([[1, 2, 2, 0],
                     [1, 3, 3, 3]], dtype=np.int32)
    vocab_size = seqs.max() + 1
    embed_dim = 8
    max_graph_size = graphs.n_node.max()

    logging.info('Training seqs: %r', seqs)

    x = seqs[:, :-1]
    y = seqs[:, 1:]

    def model_fn(vocab_size, embed_dim):
      return models.Graph2TextTransformer(
          vocab_size=vocab_size,
          emb_dim=embed_dim,
          num_layers=2,
          num_heads=4,
          cutoffs=[],
          gnn_embed_dim=embed_dim,
          gnn_num_layers=2)

    def forward(graphs, inputs, labels, max_graph_size):
      input_mask = (labels != 0).astype(jnp.float32)
      return model_fn(vocab_size, embed_dim).loss(
          graphs, max_graph_size, False, inputs, labels, mask=input_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    rng = hk.PRNGSequence(8)
    params, state = init_fn(next(rng), graphs, x, y, max_graph_size)

    def apply(*args, **kwargs):
      out, state = apply_fn(*args, **kwargs)
      return out[0], (out[1], state)
    apply = jax.jit(apply, static_argnums=6)

    optimizer = optax.chain(
        optax.scale_by_adam(),
        optax.scale(-1e-3))
    opt_state = optimizer.init(params)
    for i in range(500):
      (loss, model_state), grad = jax.value_and_grad(apply, has_aux=True)(
          params, state, next(rng), graphs, x, y, max_graph_size)
      metrics, state = model_state
      updates, opt_state = optimizer.update(grad, opt_state, params)
      params = optax.apply_updates(params, updates)
      if (i + 1) % 100 == 0:
        logging.info(
            'Step %d, %r', i + 1, {k: float(v) for k, v in metrics.items()})
    logging.info('Loss: %.8f', loss)
    self.assertLess(loss, 1.0)
Example #25
def main(_):
    optimizer = optax.adam(FLAGS.learning_rate)

    @jax.jit
    def update(params: hk.Params, prng_key: PRNGKey, opt_state: OptState,
               batch: Batch) -> Tuple[hk.Params, OptState]:
        """Single SGD update step."""
        grads = jax.grad(loss_fn)(params, prng_key, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

    prng_seq = hk.PRNGSequence(42)
    params = log_prob.init(next(prng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)))
    opt_state = optimizer.init(params)

    train_ds = load_dataset(tfds.Split.TRAIN, FLAGS.batch_size)
    valid_ds = load_dataset(tfds.Split.TEST, FLAGS.batch_size)

    for step in range(FLAGS.training_steps):
        params, opt_state = update(params, next(prng_seq), opt_state,
                                   next(train_ds))

        if step % FLAGS.eval_frequency == 0:
            val_loss = eval_fn(params, next(valid_ds))
            logging.info("STEP: %5d; Validation loss: %.3f", step, val_loss)
Example #26
  def test_bow_transformer_runs(self):
    bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1],
                    [0, 1, 0, 0, 1, 0, 1, 0],
                    [1, 0, 0, 0, 1, 0, 0, 1]], dtype=np.int32)
    seqs = np.array([[1, 2, 3, 0, 0],
                     [2, 4, 5, 6, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = seqs.max() + 1

    def forward(bow, inputs, labels):
      model = models.Bow2TextTransformer(
          vocab_size=vocab_size,
          emb_dim=16,
          num_layers=2,
          num_heads=4,
          cutoffs=[])
      return model.loss(bow, inputs, labels)

    init_fn, apply_fn = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), bow, x, y)
    out, _ = apply_fn(params, state, next(key), bow, x, y)
    loss, metrics = out

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)
Example #27
 def sample_func_type1(params, state, rng, S, A):
     rngs = hk.PRNGSequence(rng)
     dist_params, _ = self.function_type1(params, state, next(rngs),
                                          S, A, False)
     S_next = self.proba_dist.sample(dist_params, next(rngs))
     logP = self.proba_dist.log_proba(dist_params, S_next)
     return S_next, logP
Example #28
  def test_transformer_param_count(self):
    seqs = np.array([[1, 2, 3, 0, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = 267_735

    def forward(inputs, labels):
      input_mask = (labels != 0).astype(jnp.float32)
      model = models.TransformerXL(
          vocab_size=vocab_size,
          emb_dim=210,
          num_layers=2,
          num_heads=10,
          dropout_prob=0.0,
          dropout_attn_prob=0.0,
          self_att_init_scale=0.02,
          dense_init_scale=0.02,
          dense_dim=2100,
          cutoffs=(20000, 40000, 200000),  # WikiText-103
          relative_pos_clamp_len=None,
      )
      return model.loss(inputs, labels, mask=input_mask, cache_steps=2)

    init_fn, apply_fn = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), x, y)
    out, _ = apply_fn(params, state, next(key), x, y)
    loss, metrics = out

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)

    param_count = tree_size(params)
    self.assertEqual(param_count, 58_704_438)
Example #29
    def target_func(self, target_params, target_state, rng, transition_batch):
        rngs = hk.PRNGSequence(rng)

        # compute q-values
        params, state = target_params['q_targ'], target_state['q_targ']
        S_next = self.q_targ.observation_preprocessor(next(rngs),
                                                      transition_batch.S_next)
        Q_s_next, _ = self.q_targ.function_type2(params, state, next(rngs),
                                                 S_next, False)

        # action propensities
        params, state = target_params['pi_targ'], target_state['pi_targ']
        S_next = self.pi_targ.observation_preprocessor(next(rngs),
                                                       transition_batch.S_next)
        dist_params, _ = self.pi_targ.function(params, state, next(rngs),
                                               S_next, False)
        P = jax.nn.softmax(dist_params['logits'], axis=-1)

        # project
        assert P.ndim == 2, f"bad shape: {P.shape}"
        assert Q_s_next.ndim == 2, f"bad shape: {Q_s_next.shape}"
        Q_sa_next = jax.vmap(jnp.dot)(P, Q_s_next)

        f, f_inv = self.q.value_transform.transform_func, self.q_targ.value_transform.inverse_func
        return f(transition_batch.Rn + transition_batch.In * f_inv(Q_sa_next))
Example #30
  def test_transformer_with_extra_runs(self):
    extra = np.array([[1, 1, 0, 0],
                      [2, 2, 2, 2],
                      [3, 3, 3, 0]], dtype=np.int32)
    seqs = np.array([[1, 2, 3, 0, 0],
                     [2, 4, 5, 6, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = seqs.max() + 1
    extra_vocab_size = extra.max() + 1

    def forward(inputs, labels, extra):
      input_mask = (labels != 0).astype(jnp.float32)
      extra_mask = (extra != 0).astype(jnp.float32)
      extra = hk.Embed(vocab_size=extra_vocab_size, embed_dim=16)(extra)
      model = models.TransformerXL(
          vocab_size=vocab_size,
          emb_dim=16,
          num_layers=2,
          num_heads=4,
          cutoffs=[],
      )
      return model.loss(inputs, labels, mask=input_mask,
                        extra=extra, extra_mask=extra_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), x, y, extra)
    out, _ = apply_fn(params, state, next(key), x, y, extra)
    loss, metrics = out

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)