Example #1
    def compile(self):
        """Compile the update function and validation function.

    The update function is used to update parameters of the network.
    It first calls the loss function, compute the gradient,
    and finally calls the optimizer.
    """
        _train_loss_fn = hk.transform_with_state(self.train_loss_fn)
        _vag = jax.value_and_grad(_train_loss_fn.apply, has_aux=True)
        rng = jax.random.PRNGKey(42)
        params, aux = _train_loss_fn.init(rng, next(self.train_iter))
        optim = self.optimizer.init(params)
        self.state: TrainerState = TrainerState(params, aux, optim, rng)

        def _update_ops(state: TrainerState, inputs):
            rng, rng_next = jax.random.split(state.rng, 2)
            (loss, aux), grads = _vag(state.params, state.aux, rng, inputs)
            grads, optim = self.optimizer.update(grads, state.optim,
                                                 state.params)
            params = optax.apply_updates(state.params, grads)
            return TrainerState(params, aux, optim, rng_next), loss

        self.jit_update_ops = jax.jit(_update_ops)

        if self.val_loss_fn is not None:
            _val_loss_obj = hk.transform_with_state(self.val_loss_fn)

            def _val_loss_fn(inputs, state: TrainerState):
                return _val_loss_obj.apply(state.params, state.aux, state.rng,
                                           inputs)[0]

            self.jit_val_loss_fn = jax.jit(_val_loss_fn)
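For reference, a minimal self-contained sketch of the same pattern (hk.transform_with_state + jax.value_and_grad(..., has_aux=True) + optax). The toy MLP, data, and hyperparameters below are illustrative and not taken from the project above.

import haiku as hk
import jax
import jax.numpy as jnp
import optax

def loss_fn(batch):
    x, y = batch
    pred = hk.nets.MLP([32, 1])(x)
    return jnp.mean((pred - y) ** 2)

loss_t = hk.transform_with_state(loss_fn)
value_and_grad = jax.value_and_grad(loss_t.apply, has_aux=True)

rng = jax.random.PRNGKey(42)
batch = (jnp.ones([8, 4]), jnp.zeros([8, 1]))
params, state = loss_t.init(rng, batch)
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

@jax.jit
def update(params, state, opt_state, rng, batch):
    # apply returns (loss, new_state); has_aux=True treats new_state as aux.
    (loss, state), grads = value_and_grad(params, state, rng, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, state, opt_state, loss

params, state, opt_state, loss = update(params, state, opt_state, rng, batch)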
Example #2
    def test_vmap(self, module_fn: ModuleFn, shape, dtype):
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        # Expand our input since we will map over it.
        x = jnp.broadcast_to(x, (2, ) + x.shape)

        f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
        f_mapped = hk.transform_with_state(
            lambda x: hk.vmap(lambda x: module_fn()(x))(x))  # pylint: disable=unnecessary-lambda

        params, state = f_mapped.init(rng, x)

        # JAX vmap with explicitly unmapped params/state/rng. This should be
        # equivalent to `f_mapped.apply(..)` (since by default hk.vmap does not map
        # params/state/rng).
        v_apply = jax.vmap(f.apply,
                           in_axes=(None, None, None, 0),
                           out_axes=(0, None))

        module_type = descriptors.module_type(module_fn)
        atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=atol)
        jax.tree_multimap(assert_allclose,
                          f_mapped.apply(params, state, rng, x),
                          v_apply(params, state, rng, x))
Example #3
    def test_logits_config(self, resnet_class, resnet_v2):
        def model_func_logits_config_default(img):
            model = resnet_class(1000, resnet_v2=resnet_v2)
            return model(img, is_training=True)

        def model_func_logits_config_modified(img):
            model = resnet_class(1000,
                                 resnet_v2=resnet_v2,
                                 logits_config=dict(w_init=jnp.ones))
            return model(img, is_training=True)

        image = jnp.ones([2, 64, 64, 3])
        rng = jax.random.PRNGKey(0)

        model = hk.transform_with_state(model_func_logits_config_default)
        params, _ = model.init(rng, image)
        logits_keys = [k for k in params.keys() if "/logits" in k]
        self.assertLen(logits_keys, 1)

        # Check logits params are zeros
        w_logits = params[logits_keys[0]]["w"]
        np.testing.assert_allclose(jnp.zeros_like(w_logits), w_logits)

        model = hk.transform_with_state(model_func_logits_config_modified)
        params, _ = model.init(rng, image)

        # Check logits params are ones
        w_logits = params[logits_keys[0]]["w"]
        np.testing.assert_allclose(jnp.ones_like(w_logits), w_logits)
Example #4
    def setUp(self):
        super(GatedLinearNetworkTest, self).setUp()
        self._name = "test_network"
        self._rng = hk.PRNGSequence(jax.random.PRNGKey(42))

        self._output_sizes = (4, 5, 6)
        self._context_dim = 2
        self._bias_len = 3

        def gln_factory():
            return gaussian.GatedLinearNetwork(
                output_sizes=self._output_sizes,
                context_dim=self._context_dim,
                bias_len=self._bias_len,
                name=self._name,
            )

        def inference_fn(inputs, side_info):
            return gln_factory().inference(inputs, side_info, 0.5)

        def batch_inference_fn(inputs, side_info):
            return jax.vmap(inference_fn, in_axes=(0, 0))(inputs, side_info)

        def update_fn(inputs, side_info, label, learning_rate):
            params, predictions, unused_loss = gln_factory().update(
                inputs, side_info, label, learning_rate, 0.5)
            return predictions, params

        def batch_update_fn(inputs, side_info, label, learning_rate):
            predictions, params = jax.vmap(update_fn,
                                           in_axes=(0, 0, 0,
                                                    None))(inputs, side_info,
                                                           label,
                                                           learning_rate)
            avg_params = tree.map_structure(lambda x: jnp.mean(x, axis=0),
                                            params)
            return predictions, avg_params

        # Haiku transform functions.
        self._init_fn, inference_fn_ = hk.without_apply_rng(
            hk.transform_with_state(inference_fn))
        self._batch_init_fn, batch_inference_fn_ = hk.without_apply_rng(
            hk.transform_with_state(batch_inference_fn))
        _, update_fn_ = hk.without_apply_rng(
            hk.transform_with_state(update_fn))
        _, batch_update_fn_ = hk.without_apply_rng(
            hk.transform_with_state(batch_update_fn))

        self._inference_fn = jax.jit(inference_fn_)
        self._batch_inference_fn = jax.jit(batch_inference_fn_)
        self._update_fn = jax.jit(update_fn_)
        self._batch_update_fn = jax.jit(batch_update_fn_)
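For reference, a self-contained sketch of the hk.without_apply_rng(hk.transform_with_state(...)) pattern used above, where apply takes (params, state, *inputs) and no rng. The toy stateful function is illustrative.

import haiku as hk
import jax
import jax.numpy as jnp

def fn(x):
    # A toy function with both parameters (hk.Linear) and state (a counter).
    counter = hk.get_state("counter", shape=[], dtype=jnp.int32, init=jnp.zeros)
    hk.set_state("counter", counter + 1)
    return hk.Linear(4)(x)

init_fn, apply_fn = hk.without_apply_rng(hk.transform_with_state(fn))
x = jnp.ones([2, 3])
params, state = init_fn(jax.random.PRNGKey(42), x)
out, state = apply_fn(params, state, x)  # no rng argument at apply time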
Example #5
    def __init__(
        self,
        mode: str,
        init_rng: jnp.ndarray,
        config: config_dict.ConfigDict,
    ):
        """Initializes experiment."""
        super(Experiment, self).__init__(mode=mode, init_rng=init_rng)
        tf.config.experimental.set_visible_devices([], device_type='GPU')
        tf.config.experimental.set_visible_devices([], device_type='TPU')

        if mode not in ('train', 'eval', 'train_eval_multithreaded'):
            raise ValueError(f'Invalid mode {mode}.')

        self.mode = mode
        self.config = config
        self.init_rng = init_rng
        self.forward = hk.transform_with_state(self._forward_fn)

        self._predictions = None

        # Needed for checkpoint restore.
        self._params = None
        self._ema_params = None
        self._network_state = None
        self._ema_network_state = None
        self._opt_state = None

        # Track what has started.
        self._training = False
        self._evaluating = False
Example #6
File: model.py  Project: lan496/jax-xtal
def get_model_fn_t(
    num_initial_atom_features: int,
    num_atom_features: int,
    num_bond_features: int,
    num_convs: int,
    num_hidden_layers: int,
    num_hidden_features: int,
    max_num_neighbors: int,
    batch_size: int,
):
    def model_fn(batch: Batch, is_training: bool) -> jnp.ndarray:
        model = CGCNN(
            num_initial_atom_features=num_initial_atom_features,
            num_atom_features=num_atom_features,
            num_bond_features=num_bond_features,
            num_convs=num_convs,
            num_hidden_layers=num_hidden_layers,
            num_hidden_features=num_hidden_features,
            max_num_neighbors=max_num_neighbors,
            batch_size=batch_size,
            name="cgcnn",
        )
        neighbor_indices = batch["neighbor_indices"]
        atom_features = batch["atom_features"]
        bond_features = batch["bond_features"]
        num_atoms = batch["num_atoms"]
        segment_ids = batch["segment_ids"]
        return model(neighbor_indices, atom_features, bond_features, num_atoms,
                     segment_ids, is_training)

    return hk.transform_with_state(model_fn)
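A hedged usage sketch of the transformed model returned by get_model_fn_t. The hyperparameter values are placeholders, and `batch` is assumed to be a Batch dict providing the keys read inside model_fn (neighbor_indices, atom_features, bond_features, num_atoms, segment_ids).

# Hypothetical usage; all numbers below are illustrative, and `batch` must be
# built elsewhere with the keys that model_fn reads.
model = get_model_fn_t(
    num_initial_atom_features=92,
    num_atom_features=64,
    num_bond_features=41,
    num_convs=3,
    num_hidden_layers=1,
    num_hidden_features=128,
    max_num_neighbors=12,
    batch_size=2,
)
rng = jax.random.PRNGKey(0)
params, state = model.init(rng, batch, is_training=True)
predictions, state = model.apply(params, state, rng, batch, is_training=False)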
Example #7
def get_model(model_name, data_info, **kwargs):
    _MODEL_FNS = {
        "lenet":
        make_lenet5_fn,
        "resnet20":
        make_resnet20_fn,
        "resnet20_frn":
        make_resnet20_frn_fn,
        "resnet20_frn_swish":
        functools.partial(make_resnet20_frn_fn, activation=jax.nn.swish),
        "cnn_lstm":
        make_cnn_lstm,
        "smooth_cnn_lstm":
        make_smooth_cnn_lstm,
        "mlp_regression":
        make_mlp_regression,
        "mlp_regression_small":
        make_mlp_regression_small,
        "mlp_classification":
        make_mlp_classification,
        "logistic_regression":
        make_logistic_regression,
    }
    net_fn = _MODEL_FNS[model_name](data_info, **kwargs)
    net = hk.transform_with_state(net_fn)
    return net.apply, net.init
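A hedged usage sketch of the (apply, init) pair returned by get_model; note that apply comes first in the return value. The model name, data_info, and x below are placeholders, not values from the project.

# Hypothetical usage; "mlp_regression", data_info, and x stand in for whatever
# the project's make_* factories expect.
apply_fn, init_fn = get_model("mlp_regression", data_info)
rng = jax.random.PRNGKey(0)
params, state = init_fn(rng, x)
preds, state = apply_fn(params, state, rng, x)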
Example #8
    def test_numpy_and_jax_results_close(self, module_fn: ModuleFn, shape,
                                         dtype):

        f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda

        rng = jax.random.PRNGKey(42)
        x = jnp.ones(shape, dtype)
        params, state = f.init(rng, x)
        out, new_state = f.apply(params, state, rng, x)

        np_rng = np.asarray(rng)
        np_x = np.asarray(x)

        with self.subTest('init'):
            params2, state2 = f.init(np_rng, np_x)
            jax.tree_multimap(np.testing.assert_allclose, params, params2)
            jax.tree_multimap(np.testing.assert_allclose, state, state2)

        with self.subTest('apply'):
            np_params = jax.tree_map(np.asarray, params)
            np_state = jax.tree_map(np.asarray, state)
            out2, new_state2 = f.apply(np_params, np_state, np_rng, np_x)
            jax.tree_multimap(np.testing.assert_allclose, out, out2)
            jax.tree_multimap(np.testing.assert_allclose, new_state,
                              new_state2)
Example #9
 def test_forward_shape(self):
     """Test output shape of SparseGCNPredicator"""
     forward = hk.transform_with_state(self.__forward)
     params, state = forward.init(next(self.key), *self.input_data)
     preds, _ = forward.apply(params, state, next(self.key),
                              *self.input_data)
     assert preds.shape == (batch_size, n_out)
Example #10
 def test_forward_shape(self):
     """Test output shape of PadGCN"""
     forward = hk.transform_with_state(self.__forward)
     params, state = forward.init(next(self.key), *self.input_data)
     preds, _ = forward.apply(params, state, next(self.key),
                              *self.input_data)
     assert preds.shape == (batch_size, max_node_size, hidden_feats[-1])
Example #11
    def __init__(self, mode, init_rng, config):
        """Initializes experiment."""

        super(Experiment, self).__init__(mode=mode, init_rng=init_rng)

        self.mode = mode
        self.init_rng = init_rng
        self.config = config

        # Checkpointed experiment state.
        self._params = None
        self._state = None
        self._opt_state = None

        # Input pipelines.
        self._train_input = None
        self._eval_input = None

        self.forward = hk.transform_with_state(self._forward_fn)

        # NOTE: We "donate" the `params, state, opt_state` arguments which allows
        # JAX (on some backends) to reuse the device memory associated with these
        # inputs to store the outputs of our function (which also start with
        # `params, state, opt_state`).
        self._update_func = jax.pmap(self._update_func,
                                     axis_name='i',
                                     donate_argnums=(0, 1, 2))
        self._eval_batch = jax.jit(self._eval_batch)
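A self-contained sketch of the buffer-donation pattern described in the NOTE above: the donated arguments (params, state, opt_state) lead the argument list and the outputs start with the same trees, so JAX may reuse their device buffers on supporting backends. The toy update function and shapes are illustrative.

import jax
import jax.numpy as jnp

def toy_update(params, state, opt_state, batch):
    # Outputs have the same structure and order as the donated inputs.
    params = jax.tree_map(lambda p: p - 0.1 * jnp.mean(batch), params)
    return params, state, opt_state

n = jax.local_device_count()
replicate = lambda tree: jax.tree_map(lambda x: jnp.stack([x] * n), tree)
params = replicate({"w": jnp.ones([3])})
state = replicate({"ema": jnp.zeros([3])})
opt_state = replicate({"m": jnp.zeros([3])})
batch = jnp.ones([n, 8])

p_update = jax.pmap(toy_update, axis_name="i", donate_argnums=(0, 1, 2))
params, state, opt_state = p_update(params, state, opt_state, batch)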
Example #12
    def __init__(self, func, observation_space, action_space=None, random_seed=None):

        if not isinstance(observation_space, Space):
            raise TypeError(
                f"observation_space must be derived from gym.Space, got: {type(observation_space)}")
        self.observation_space = observation_space

        if action_space is not None:
            if not isinstance(action_space, Space):
                raise TypeError(
                    f"action_space must be derived from gym.Space, got: {type(action_space)}")
            self.action_space = action_space

        self.random_seed = random_seed  # also initializes self.rng via RandomStateMixin
        self._jitted_funcs = {}

        # Haiku-transform the provided func
        example_data = self._check_signature(func)
        static_argnums = tuple(i + 3 for i in example_data.inputs.static_argnums)
        transformed = hk.transform_with_state(func)
        self._function = jit(transformed.apply, static_argnums=static_argnums)

        # init function params and state
        self._params, self._function_state = transformed.init(self.rng, *example_data.inputs.args)

        # check if output has the expected shape etc.
        output, _ = \
            self._function(self.params, self.function_state, self.rng, *example_data.inputs.args)
        self._check_output(output, example_data.output)

        def soft_update_func(old, new, tau):
            return jax.tree_multimap(lambda a, b: (1 - tau) * a + tau * b, old, new)

        self._soft_update_func = jit(soft_update_func)
Example #13
  def test_graph_conditioned_transformer_learns(self):
    graphs = jraph.GraphsTuple(
        nodes=np.ones((4, 3), dtype=np.float32),
        edges=np.ones((3, 1), dtype=np.float32),
        senders=np.array([0, 2, 3], dtype=np.int32),
        receivers=np.array([1, 3, 2], dtype=np.int32),
        n_node=np.array([2, 2], dtype=np.int32),
        n_edge=np.array([1, 2], dtype=np.int32),
        globals=None,
        )
    seqs = np.array([[1, 2, 2, 0],
                     [1, 3, 3, 3]], dtype=np.int32)
    vocab_size = seqs.max() + 1
    embed_dim = 8
    max_graph_size = graphs.n_node.max()

    logging.info('Training seqs: %r', seqs)

    x = seqs[:, :-1]
    y = seqs[:, 1:]

    def model_fn(vocab_size, embed_dim):
      return models.Graph2TextTransformer(
          vocab_size=vocab_size,
          emb_dim=embed_dim,
          num_layers=2,
          num_heads=4,
          cutoffs=[],
          gnn_embed_dim=embed_dim,
          gnn_num_layers=2)

    def forward(graphs, inputs, labels, max_graph_size):
      input_mask = (labels != 0).astype(jnp.float32)
      return model_fn(vocab_size, embed_dim).loss(
          graphs, max_graph_size, False, inputs, labels, mask=input_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    rng = hk.PRNGSequence(8)
    params, state = init_fn(next(rng), graphs, x, y, max_graph_size)

    def apply(*args, **kwargs):
      out, state = apply_fn(*args, **kwargs)
      return out[0], (out[1], state)
    apply = jax.jit(apply, static_argnums=6)

    optimizer = optax.chain(
        optax.scale_by_adam(),
        optax.scale(-1e-3))
    opt_state = optimizer.init(params)
    for i in range(500):
      (loss, model_state), grad = jax.value_and_grad(apply, has_aux=True)(
          params, state, next(rng), graphs, x, y, max_graph_size)
      metrics, state = model_state
      updates, opt_state = optimizer.update(grad, opt_state, params)
      params = optax.apply_updates(params, updates)
      if (i + 1) % 100 == 0:
        logging.info(
            'Step %d, %r', i + 1, {k: float(v) for k, v in metrics.items()})
    logging.info('Loss: %.8f', loss)
    self.assertLess(loss, 1.0)
Example #14
  def test_bow_transformer_runs(self):
    bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1],
                    [0, 1, 0, 0, 1, 0, 1, 0],
                    [1, 0, 0, 0, 1, 0, 0, 1]], dtype=np.int32)
    seqs = np.array([[1, 2, 3, 0, 0],
                     [2, 4, 5, 6, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = seqs.max() + 1

    def forward(bow, inputs, labels):
      model = models.Bow2TextTransformer(
          vocab_size=vocab_size,
          emb_dim=16,
          num_layers=2,
          num_heads=4,
          cutoffs=[])
      return model.loss(bow, inputs, labels)

    init_fn, apply_fn = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), bow, x, y)
    out, _ = apply_fn(params, state, next(key), bow, x, y)
    loss, metrics = out

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)
Example #15
    def test_vmap(
        self,
        module_fn: ModuleFn,
        shape: Shape,
        dtype: DType,
    ):
        batch_size, shape = shape[0], shape[1:]
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            sample = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            sample = jax.random.uniform(rng, shape, dtype)
        batch = jnp.broadcast_to(sample, (batch_size, ) + sample.shape)

        def g(x):
            return module_fn()(x)

        f = hk.transform_with_state(g)

        # Ensure application under vmap is the same.
        params, state = f.init(rng, sample)
        v_apply = jax.vmap(f.apply, in_axes=(None, None, None, 0))
        jax.tree_multimap(
            lambda a, b: np.testing.assert_allclose(a, b, atol=DEFAULT_ATOL),
            f.apply(params, state, rng, batch),
            v_apply(params, state, rng, batch))
Example #16
    def test_jit(
        self,
        module_fn: ModuleFn,
        shape: Shape,
        dtype: DType,
    ):
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        def g(x):
            return module_fn()(x)

        f = hk.transform_with_state(g)

        atol = CUSTOM_ATOL.get(module_type(module_fn), DEFAULT_ATOL)
        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=atol)

        # Ensure initialization under jit is the same.
        jax.tree_multimap(assert_allclose, f.init(rng, x),
                          jax.jit(f.init)(rng, x))

        # Ensure application under jit is the same.
        params, state = f.init(rng, x)
        jax.tree_multimap(assert_allclose, f.apply(params, state, rng, x),
                          jax.jit(f.apply)(params, state, rng, x))
Example #17
  def test_hk_remat(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
      x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
      x = jax.random.uniform(rng, shape, dtype)

    def g(x, remat=False):
      mod = module_fn()
      if remat:
        mod = hk.remat(mod)
      out = mod(x)
      if isinstance(out, dict):
        out = out['loss']
      return jnp.mean(out)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
    grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                             has_aux=True)

    params, state = f.init(rng, x)
    jax.tree_multimap(assert_allclose,
                      grad_jax_remat(params, state, rng, x),
                      grad_hk_remat(params, state, rng, x))
Example #18
  def test_optimize_rng_use_under_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
      x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
      x = jax.random.uniform(rng, shape, dtype)

    def g(x):
      return module_fn()(x)

    f = hk.transform_with_state(hk.experimental.optimize_rng_use(g))

    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    params, state = jax.jit(f.init)(rng, x)
    jax.tree_multimap(assert_allclose, (params, state), f.init(rng, x))

    if module_type in (hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA):
      # For stochastic modules just test apply runs.
      jax.device_get(jax.jit(f.apply)(params, state, rng, x))

    else:
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.apply)(params, state, rng, x),
                        f.apply(params, state, rng, x))
Example #19
  def test_hk_jit(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
      init: bool,
  ):
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
      x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
      x = jax.random.uniform(rng, shape, dtype)

    def g(x, jit=False):
      mod = module_fn()
      if jit:
        mod = hk.jit(mod)
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-4)

    # NOTE: We shard init/apply tests since some modules are expensive to jit
    # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
    if init:
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.init)(rng, x),
                        f.init(rng, x, jit=True))

    else:
      params, state = f.init(rng, x)
      jax.tree_multimap(assert_allclose,
                        jax.jit(f.apply)(params, state, rng, x),
                        f.apply(params, state, rng, x, jit=True))
Example #20
    def __init__(
        self,
        output_dim: int,
        activation_fn: str = "relu",
        stochastic_parameters: bool = False,
        linear_model: bool = False,
        dropout: bool = False,
        dropout_rate: float = 0.0,
    ):
        """Wrapper of resnet50_fsvi

    Args:
    output_dim: the output dimension
    activation_fn: the type of activation function, e.g. "relu", "tanh"
    stochastic_parameters: if True, we keep a variational distribution of
      parameters.
    linear_model: if True, only put variational distribution on the last layer.
    dropout: if True, apply dropout.
    dropout_rate: dropout rate if we apply dropout.
    """
        self.output_dim = output_dim
        self.linear_model = linear_model
        self.dropout = dropout
        self.dropout_rate = dropout_rate
        self.activation_fn = ACTIVATION_DICT[activation_fn]
        self.stochastic_parameters = stochastic_parameters

        self.forward = hk.transform_with_state(self.make_forward_fn())
Example #21
  def test_profiler_name_scopes(
      self,
      module_fn: descriptors.ModuleFn,
      shape: Shape,
      dtype: DType,
  ):
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
      x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
      x = jax.random.uniform(rng, shape, dtype)

    def g(x, name_scopes=False):
      hk.experimental.profiler_name_scopes(enabled=name_scopes)
      mod = module_fn()
      return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    params, state = f.init(rng, x)
    jax.tree_multimap(assert_allclose,
                      f.apply(params, state, rng, x),
                      f.apply(params, state, rng, x, name_scopes=True))

    # TODO(lenamartens): flip to True when default changes
    hk.experimental.profiler_name_scopes(enabled=False)
Example #22
 def test_forward_shape_with_batch_norm(self):
     """Test output shape of PadGCNLayer with BatchNorm"""
     forward = hk.transform_with_state(self.__forward_with_batch_norm)
     params, state = forward.init(next(self.key), *self.input_data)
     preds, _ = forward.apply(params, state, next(self.key),
                              *self.input_data)
     assert preds.shape == (batch_size, max_node_size, out_feats)
Example #23
 def test_abstract_to_dot(self, module_fn: ModuleFn, shape, dtype):
   f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
   rng = jax.random.PRNGKey(42)
   x = np.ones(shape, dtype)
   params, state = jax.eval_shape(f.init, rng, x)
   self.assertIsNotNone(
       hk.experimental.abstract_to_dot(f.apply)(params, state, rng, x))
Example #24
  def test_transformer_with_extra_runs(self):
    extra = np.array([[1, 1, 0, 0],
                      [2, 2, 2, 2],
                      [3, 3, 3, 0]], dtype=np.int32)
    seqs = np.array([[1, 2, 3, 0, 0],
                     [2, 4, 5, 6, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = seqs.max() + 1
    extra_vocab_size = extra.max() + 1

    def forward(inputs, labels, extra):
      input_mask = (labels != 0).astype(jnp.float32)
      extra_mask = (extra != 0).astype(jnp.float32)
      extra = hk.Embed(vocab_size=extra_vocab_size, embed_dim=16)(extra)
      model = models.TransformerXL(
          vocab_size=vocab_size,
          emb_dim=16,
          num_layers=2,
          num_heads=4,
          cutoffs=[],
      )
      return model.loss(inputs, labels, mask=input_mask,
                        extra=extra, extra_mask=extra_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), x, y, extra)
    out, _ = apply_fn(params, state, next(key), x, y, extra)
    loss, metrics = out

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)
Example #25
  def test_transformer_param_count(self):
    seqs = np.array([[1, 2, 3, 0, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    x = seqs[:, :-1]
    y = seqs[:, 1:]
    vocab_size = 267_735

    def forward(inputs, labels):
      input_mask = (labels != 0).astype(jnp.float32)
      model = models.TransformerXL(
          vocab_size=vocab_size,
          emb_dim=210,
          num_layers=2,
          num_heads=10,
          dropout_prob=0.0,
          dropout_attn_prob=0.0,
          self_att_init_scale=0.02,
          dense_init_scale=0.02,
          dense_dim=2100,
          cutoffs=(20000, 40000, 200000),  # WikiText-103
          relative_pos_clamp_len=None,
      )
      return model.loss(inputs, labels, mask=input_mask, cache_steps=2)

    init_fn, apply_fn = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = init_fn(next(key), x, y)
    out, _ = apply_fn(params, state, next(key), x, y)
    loss, metrics = out

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)

    param_count = tree_size(params)
    self.assertEqual(param_count, 58_704_438)
Example #26
    def __init__(self, loss_fn, optimizer, devices=None, has_graph=False):
        self._net_init_fn, self._apply_fn = hk.transform_with_state(
            functools.partial(loss_fn, is_training=True))
        _, self._eval_apply_fn = hk.transform_with_state(
            functools.partial(loss_fn, is_training=False))

        if optimizer is None:
            optimizer = optax.identity()
        self._optimizer = optimizer

        self._num_devices = jax.local_device_count()
        if devices is None:
            devices = []
            for host_id in range(jax.process_count()):
                for device_id in jax.local_devices(host_id):
                    devices.append(device_id)
        else:
            self._num_devices = min(self._num_devices, len(devices))

        def _pmap(f, static_broadcasted_argnums=()):
            return jax.pmap(
                f,
                axis_name='i',
                devices=devices,
                static_broadcasted_argnums=static_broadcasted_argnums)

        def handle_graph_size(fn):
            def _fn(*args):
                batch = args[-1].copy()
                max_graph_size = batch['max_graph_size']
                del batch['max_graph_size']
                args = args[:-1] + (batch, max_graph_size)
                return fn(*args)

            return _fn

        # Try to jit.
        if has_graph:
            # If the model contains full graphs, we need to set the max_graph_size
            # as a statically broadcasted argument.
            self._init_fn = handle_graph_size(_pmap(self._init, 4))
            self._update_fn = handle_graph_size(_pmap(self._update, 2))
            self._eval_fn = handle_graph_size(_pmap(self._eval, 2))
        else:
            self._init_fn = _pmap(self._init)
            self._update_fn = _pmap(self._update)
            self._eval_fn = _pmap(self._eval)
Example #27
 def __init__(self, *args, **kwargs):
     self._model = hk.transform_with_state(
         lambda *a, **k:  # pylint: disable=g-long-lambda,unnecessary-lambda
         s3d.S3D(normalize_fn=normalization.get_normalize_fn(),
                 *args,
                 **kwargs)(*a, **k))
     self._rng = jax.random.PRNGKey(42)
     self._params, self._state = None, None
Example #28
    def test_hk_scan(self, module_fn: descriptors.ModuleFn, shape, dtype,
                     init):
        rng = jax.random.PRNGKey(42)
        if jnp.issubdtype(dtype, jnp.integer):
            x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
        else:
            x = jax.random.uniform(rng, shape, dtype)

        def f(x):
            mod = module_fn()
            return mod(x)

        def u_f(xs):
            mod = module_fn()

            def s(carry, x):
                y = mod(x)
                return carry, y

            _, ys = hk.scan(s, (), xs)
            return ys

        u_f = hk.transform_with_state(u_f)
        f = hk.transform_with_state(f)

        assert_allclose = functools.partial(np.testing.assert_allclose,
                                            atol=1e-4)
        xs = jnp.broadcast_to(x, (8, ) + x.shape)
        params, state = f.init(rng, x)

        if init:
            u_params, u_state = u_f.init(rng, xs)
            jax.tree_multimap(assert_allclose, u_params, params)
            jax.tree_multimap(assert_allclose, u_state, state)
            return

        def fun(state, x):
            y, state = f.apply(params, state, rng, x)
            return state, y

        s_state, s_ys = jax.lax.scan(fun, state, xs)
        u_ys, u_state = u_f.apply(params, state, rng, xs)

        jax.tree_multimap(assert_allclose, u_ys, s_ys)
        jax.tree_multimap(assert_allclose, u_state, s_state)
Example #29
    def __init__(self, random_seed, num_classes, batch_size, max_steps,
                 enable_double_transpose, checkpoint_to_evaluate,
                 allow_train_from_scratch, freeze_backbone, network_config,
                 optimizer_config, lr_schedule_config, evaluation_config,
                 checkpointing_config):
        """Constructs the experiment.

    Args:
      random_seed: the random seed to use when initializing network weights.
      num_classes: the number of classes; used for the online evaluation.
      batch_size: the total batch size; should be a multiple of the number of
        available accelerators.
      max_steps: the number of training steps; used for the lr/target network
        ema schedules.
      enable_double_transpose: see dataset.py; only has effect on TPU.
      checkpoint_to_evaluate: the path to the checkpoint to evaluate.
      allow_train_from_scratch: whether to allow training without specifying a
        checkpoint to evaluate (training from scratch).
      freeze_backbone: whether the backbone resnet should remain frozen (linear
        evaluation) or be trainable (fine-tuning).
      network_config: the configuration for the network.
      optimizer_config: the configuration for the optimizer.
      lr_schedule_config: the configuration for the learning rate schedule.
      evaluation_config: the evaluation configuration.
      checkpointing_config: the configuration for checkpointing.
    """

        self._random_seed = random_seed
        self._enable_double_transpose = enable_double_transpose
        self._num_classes = num_classes
        self._lr_schedule_config = lr_schedule_config
        self._batch_size = batch_size
        self._max_steps = max_steps
        self._checkpoint_to_evaluate = checkpoint_to_evaluate
        self._allow_train_from_scratch = allow_train_from_scratch
        self._freeze_backbone = freeze_backbone
        self._optimizer_config = optimizer_config
        self._evaluation_config = evaluation_config

        # Checkpointed experiment state.
        self._experiment_state = None

        # Input pipelines.
        self._train_input = None
        self._eval_input = None

        backbone_fn = functools.partial(self._backbone_fn, **network_config)
        self.forward_backbone = hk.without_apply_rng(
            hk.transform_with_state(backbone_fn))
        self.forward_classif = hk.without_apply_rng(
            hk.transform(self._classif_fn))
        self.update_pmap = jax.pmap(self._update_func, axis_name='i')
        self.eval_batch_jit = jax.jit(self._eval_batch)

        self._is_backbone_training = not self._freeze_backbone

        self._checkpointer = checkpointing.Checkpointer(**checkpointing_config)
Example #30
 def test_info_and_html(self, module_fn: ModuleFn, shape, dtype):
   x = jnp.ones(shape, dtype)
   f = hk.transform_with_state(lambda: module_fn()(x))
   rng = jax.random.PRNGKey(42)
   params, state = f.init(rng)
   info = jaxpr_info.make_model_info(f.apply)(params, state, rng)
   if descriptors.module_type(module_fn).__name__ != 'Sequential':
     self.assertNotEmpty(info.expressions)
   self.assertIsNotNone(jaxpr_info.as_html_page(info))