Exemplo n.º 1
0
    def forward_with_state(self, inputs, weights, state, rng=None):
        """Runs this layer's forward pass and returns the output and new state.

        Args:
          inputs: Layer inputs; a single array or a tuple/list of arrays
            (subclasses may use different inputs).
          weights: Layer weights.
          state: Complete state of the layer.
          rng: PRNG key (not read by this method).

        Returns:
          A tuple (output, new_state).
        """
        if not self.use_reference_code:
            # Default path: delegate to the efficient, batched implementation.
            output, updated_state, _, _ = self.forward_and_or_backward(
                inputs, weights, state, compute_output=True, update_state=True)
            return output, updated_state

        # Reference path: a readable, unoptimized per-example / per-head loop
        # that should only be used when testing this class for correctness.
        if not isinstance(inputs, (tuple, list)):
            inputs = (inputs, )
        first_input = inputs[0]
        batch_size = int(first_input.shape[0])
        seqlen = first_input.shape[-2]
        d_model = first_input.shape[-1]

        per_example_outputs = []
        per_head_states = []
        for example_idx in range(batch_size):
            # Outputs of all heads are summed for each example.
            example_total = np.zeros((seqlen, d_model))
            for head_idx in range(self.n_heads):
                # State entries are laid out example-major, head-minor.
                state_idx = example_idx * self.n_heads + head_idx
                # pylint: disable=cell-var-from-loop
                example_inputs = jax.tree_map(lambda x: x[example_idx], inputs)
                head_weights = jax.tree_map(lambda w: w[head_idx], weights)
                head_state = jax.tree_map(lambda s: s[state_idx], state)
                # pylint: enable=cell-var-from-loop
                head_out, head_new_state = self.forward_unbatched(
                    *example_inputs,
                    weights=head_weights,
                    state=head_state,
                    update_state=True)
                per_head_states.append(head_new_state)
                example_total = example_total + head_out
            per_example_outputs.append(example_total)

        output = np.stack(per_example_outputs, 0)
        if per_head_states and jax.tree_leaves(per_head_states[0]):
            # Re-stack the per-(example, head) states along a new leading axis.
            stacked_state = jax.tree_multimap(lambda *s: np.stack(s, 0),
                                              *per_head_states)
            return output, stacked_state
        # No state leaves were produced: hand back the incoming state as is.
        return output, state
Exemplo n.º 2
0
    def new_weights_and_state(self, input_signature):
        """Creates stacked per-head weights and per-(example, head) state.

        Args:
          input_signature: Signature (or tuple/list of signatures) of the
            batched inputs; its leading dimension is taken as the batch size.

        Returns:
          A tuple (weights, state). Each weights leaf stacks `self.n_heads`
          unbatched entries along axis 0; each state leaf stacks
          `self.n_heads * batch_size` unbatched entries along axis 0.
        """

        def drop_leading_dim(sig):
            # Rebuild the signature with the same type, minus the batch axis.
            return type(sig)(shape=sig.shape[1:], dtype=sig.dtype)

        def stack_leaves(*leaves):
            return np.stack(leaves, axis=0)

        unbatched_signature = jax.tree_map(drop_leading_dim, input_signature)
        if isinstance(input_signature, (tuple, list)):
            leading_signature = input_signature[0]
        else:
            leading_signature = input_signature
        batch_size = int(leading_signature.shape[0])

        # One set of weights per head.
        head_rngs = self.new_rngs(self.n_heads)
        per_head_weights = [
            self.create_weights_unbatched(unbatched_signature, head_rngs[h])
            for h in range(self.n_heads)
        ]

        # One state entry per (example, head) pair.
        n_state_entries = self.n_heads * batch_size
        entry_rngs = self.new_rngs(n_state_entries)
        per_entry_state = [
            self.create_state_unbatched(unbatched_signature, entry_rngs[e])
            for e in range(n_state_entries)
        ]

        stacked_weights = jax.tree_multimap(stack_leaves, *per_head_weights)
        stacked_state = jax.tree_multimap(stack_leaves, *per_entry_state)
        return stacked_weights, stacked_state
Exemplo n.º 3
0
 def forward_with_state(self, inputs, weights=base.EMPTY_WEIGHTS,
                        state=base.EMPTY_STATE, rng=None):
   """Adds a slice of `weights` to `inputs`, with optional dropout on it.

   Args:
     inputs: Array whose axis 1 gives the number of positions to encode.
     weights: Encoding table, indexed as weights[:, :length, :].
     state: In 'predict' mode, the current position index; it is returned
       advanced by the number of positions consumed. Returned unchanged in
       'train' and 'eval' modes.
     rng: PRNG key, used to sample the dropout mask when dropout is nonzero.

   Returns:
     A tuple (output, new_state).
   """
   if self._mode not in ('train', 'eval'):
     assert self._mode == 'predict'
     assert self._dropout == 0
     # State in this class is only used for fast inference. In that case,
     # the model is called with consecutive elements position-by-position,
     # so this layer stores the index of the current position in its state
     # and advances it on every call.
     n_positions = inputs.shape[1]
     if n_positions == 1:
       single_step = jnp.expand_dims(weights[0, state, :], 1)
       return (inputs + single_step, state + 1)
     # Several positions at once: take one slice per batch row, each
     # starting at that row's own position index.
     row_slices = [
         jax.lax.dynamic_slice_in_dim(
             weights[0], state[row], n_positions, axis=0)
         for row in range(inputs.shape[0])
     ]
     return inputs + jnp.stack(row_slices, 0), state + n_positions

   # 'train' / 'eval': add the first `length` rows of the encoding table.
   length = jnp.shape(inputs)[1]
   encodings = weights[:, :length, :]
   if self._dropout == 0:
     return (inputs + encodings, state)

   # Dropout on the encodings; broadcast dims share one mask entry.
   mask_shape = list(encodings.shape)
   for axis in self._dropout_broadcast_dims:
     mask_shape[axis] = 1
   keep_prob = 1.0 - self._dropout
   if math.backend_name() == 'jax':
     keep_prob = jax.lax.tie_in(
         inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
   keep = math.random.bernoulli(rng, keep_prob, tuple(mask_shape))
   scale = keep.astype(inputs.dtype) / keep_prob
   return (inputs + encodings * scale, state)
Exemplo n.º 4
0
 def predict(x, weights, state, rng):
   """Predict function jitted and parallelized as requested.

   Splits `x` and `rng` across `n_devices`, runs `model_predict`, combines
   the per-device results and averages every output over axis 0.
   """
   sharded_x = reshape_by_device(x, n_devices)
   device_rngs = np.stack(math.random.split(rng, n_devices))
   per_device = model_predict(sharded_x, weights, state, device_rngs)
   res, state = _combine_devices(per_device)

   def mean_over_axis_0(y):
     return np.mean(y, axis=0)

   return math.nested_map(mean_over_axis_0, res), state
Exemplo n.º 5
0
 def predict(x, weights, state, rng):
     """Predict function, JIT-compiled and parallelized as requested.

     Splits `x` and `rng` across `n_devices`, runs `model_predict` and
     combines the per-device results. When `do_mean` is set, every output is
     additionally averaged over axis 0.
     """
     sharded_x = reshape_by_device(x, n_devices)
     device_rngs = jnp.stack(math.random.split(rng, n_devices))
     res, state = _combine_devices(
         model_predict(sharded_x, weights, state, device_rngs))
     if not do_mean:
         return res, state
     averaged = math.nested_map(lambda y: jnp.mean(y, axis=0), res)
     return averaged, state
Exemplo n.º 6
0
    def forward(self, inputs):
        """Adds a slice of `self.weights` to `inputs`, with optional dropout.

        Args:
          inputs: Array whose axis 1 gives the number of positions to encode.

        Returns:
          `inputs` plus the selected rows of `self.weights`.

        Raises:
          ValueError: If the layer is in 'predict' mode with nonzero dropout.
        """
        if self._mode == 'predict':
            if self._dropout != 0:
                raise ValueError(f'In predict mode, but dropout rate '
                                 f'({self._dropout}) is not zero.')

            # State in this class is only used for fast inference. In that
            # case, the model is called with consecutive elements
            # position-by-position, so this layer keeps the index of the
            # current position in its state and advances it on every call.
            position = self.state
            n_positions = inputs.shape[1]
            if n_positions == 1:
                self.state = position + 1
                return inputs + jnp.expand_dims(
                    self.weights[0, position, :], 1)

            # Several positions at once: one slice per batch row, each
            # starting at that row's own position index.
            row_slices = [
                jax.lax.dynamic_slice_in_dim(self.weights[0],
                                             position[row],
                                             n_positions,
                                             axis=0)
                for row in range(inputs.shape[0])
            ]
            self.state = position + n_positions
            return inputs + jnp.stack(row_slices, 0)

        # Any other mode: add the first `length` rows of the encoding table.
        length = jnp.shape(inputs)[1]
        encodings = self.weights[:, :length, :]
        if self._dropout == 0:
            return inputs + encodings

        # Dropout on the encodings; broadcast dims share one mask entry.
        mask_shape = list(encodings.shape)
        for axis in self._dropout_broadcast_dims:
            mask_shape[axis] = 1
        keep_prob = 1.0 - self._dropout
        if math.backend_name() == 'jax':
            keep_prob = jax.lax.tie_in(
                inputs, jnp.full((), keep_prob, dtype=inputs.dtype))
        keep = math.random.bernoulli(self.rng, keep_prob, tuple(mask_shape))
        scale = keep.astype(inputs.dtype) / keep_prob
        return inputs + encodings * scale
Exemplo n.º 7
0
  def _get_embeddings(self, t):
    """Builds time-bin features for the given positions.

    Args:
      t: int[...] position (i.e. jnp.arange(..., jnp.int32))

    Returns:
      embeddings: float[..., num_features]; along the last axis these are
        1 / (1 + bin index), the fractional offset inside the bin, and the
        parity of the bin index.
    """
    bin_length = self._time_bin_length
    bin_index, offset_in_bin = divmod(t, bin_length)
    inverse_bin = 1 / (1 + bin_index)
    fraction = offset_in_bin / bin_length
    parity = (bin_index % 2).astype(jnp.float32)
    embeddings = jnp.stack([inverse_bin, fraction, parity], -1)

    expected_shape = t.shape + (self.num_features,)
    assert embeddings.shape == expected_shape, embeddings.shape
    return embeddings
Exemplo n.º 8
0
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
    """Splits a key into a stream of random keys.

    The counter values lo, lo + 1, ..., hi - 1 are fed through
    `threefry_2x32_prf` under `key`, using the little-endian counter mode:
    the low 32-bit word carries the counter and the high word is zero.

    Args:
      key: uint32[2] the key to split
      lo: the range to start extracting from
      hi: the range to stop extracting from

    Returns:
      keys: uint32[hi - lo, 2] the split keys

    Raises:
      ValueError: If `key` is not a uint32 array of shape (2,).
      NotImplementedError: If `hi` does not fit in 32 bits.
    """
    if key.shape != (2, ) or key.dtype != np.uint32:
        raise ValueError('key must be uint32[2]')
    if hi >= 2**32:
        # You shouldn't really be using more than half the key size anyways.
        raise NotImplementedError('only 32-bit sizes are supported')
    # Assemble the 64-bit counters as (low word, high word) pairs.
    low_words = np.arange(lo, hi, dtype=np.uint32)
    high_words = np.zeros_like(low_words)
    counters = np.stack([low_words, high_words], axis=-1)
    return threefry_2x32_prf(key, counters)
Exemplo n.º 9
0
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 id_to_mask=None,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):
        """Sets up models, optimizer, metrics and jitted step functions.

        Args:
          model: Callable invoked as `model(mode='train')` and
            `model(mode='eval')` to build the training and evaluation models.
          loss_fn: Callable invoked as
            `loss_fn(has_weights=..., id_to_mask=...)`; its result is used as
            the loss layer.
          optimizer: Callable invoked as `optimizer(learning_rate=0.0)`.
          lr_schedule: Stored as `self._lr_schedule`.
          inputs: Either an Inputs instance or a function that returns one.
          output_dir: Passed to `reset()`. When falsy, checkpoint saving and
            summary writing are both disabled.
          random_seed: Passed to `_init_host_and_devices`.
          n_devices: Passed to `_init_host_and_devices`.
          checkpoints_at: Stored as `self._checkpoints_at`; defaults to [].
          should_save_checkpoints: Only honoured on the chief host and when
            `output_dir` is set.
          should_write_summaries: Only honoured when `output_dir` is set.
          has_weights: Forwarded to `loss_fn` and to every metric constructor.
          nontrainable_param_map: Stored as `self._nontrainable_param_map`;
            defaults to an empty dict.
          id_to_mask: Forwarded to `loss_fn` and to every metric constructor.
          metrics: Dict mapping metric names to metric constructors; defaults
            to `_DEFAULT_METRICS`.
          checkpoint_highest: Stored as `self._checkpoint_highest`.
          checkpoint_lowest: Stored as `self._checkpoint_lowest`.
        """

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        # Without an output directory there is nowhere to write anything.
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._has_weights = has_weights
        self._id_to_mask = id_to_mask
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        # From here on `loss_fn` is the constructed loss layer, not the factory.
        loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(
                inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')

        # Setup state.
        # `init_rng` is reserved for initialization; the remaining `rng` is
        # split into one key per device.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        # NOTE(review): `model_input_shape` and `model_target_shape` computed
        # below are not read again anywhere in this method -- confirm whether
        # they are still needed.
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                                   (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.input_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(self._inputs.input_shape))
        # Same for targets.
        if self._inputs.target_shape and isinstance(
                self._inputs.target_shape[0], (list, tuple)):
            model_target_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.target_shape)
        else:
            model_target_shape = tuple([None] +
                                       list(self._inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = math.nested_map(lambda x: x or 1,
                                            model_input_shape)
        model_target_shape = math.nested_map(lambda x: x or 1,
                                             model_target_shape)

        def new_opt_state_and_model_state(shape_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            shapes, dtypes = shape_dtype
            input_signature = tuple(
                ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.initialize` puts cached parameter values in `m` and hence the
            # next call of `m.initialize` will give wrong results.
            m = tl.Serial(model(mode='train'), loss_fn)
            m._set_rng_recursive(rng)  # pylint: disable=protected-access
            weights, state = m.init(input_signature)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state, static_argnums=(0, ))
        # Zero-argument factory binding the example shapes/dtypes and init key.
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                self._inputs.example_shape_dtype, init_rng))

        # Arrange and initialize metrics layers.
        # Metric names are kept in sorted order; the layers follow that order.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [
            self._metrics_dict[m](has_weights=self._has_weights,
                                  id_to_mask=self._id_to_mask)
            for m in self._metrics
        ]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        # The metrics are initialized on the eval model's output signature.
        model_predict_eval.init(example_signature)
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
Exemplo n.º 10
0
def pad_trajectories(trajectories, boundary=20):
    """Pad trajectories to a bucket length that is a multiple of boundary.

    Args:
      trajectories: list[(observation, actions, rewards, infos)], where each
        observation is shaped (t+1,) + OBS and actions & rewards are shaped
        (t,), with the length of the list being B (batch size). `infos` is
        either falsy or a dict of arrays, each padded along its first axis.
      boundary: int, bucket length, the actions and rewards are padded to
        integer multiples of boundary.

    Returns:
      tuple: (padded_lengths, reward_mask, padded_observations,
          padded_actions, padded_rewards, padded_infos) where
          padded_observations is shaped (B, T+1) + OBS and padded_actions,
          padded_rewards & reward_mask are shaped (B, T).
          Where T is max(t) rounded up to an integer multiple of boundary.
          padded_lengths is a list with how much padding was added to each
          trajectory, reward_mask is 1s for actual rewards and 0s for the
          padding, and padded_infos is a dict of stacked, padded info arrays
          (None when no trajectory carried infos).
    """

    # Let's compute max(t) over all trajectories.
    t_max = max(r.shape[0] for (_, _, r, _) in trajectories)

    # t_max is rounded to the next multiple of `boundary`
    boundary = int(boundary)
    bucket_length = boundary * int(np.ceil(float(t_max) / boundary))

    # So all obs will be padded to bucket_length + 1 and actions and rewards
    # to bucket_length.
    padded_observations = []
    padded_actions = []
    padded_rewards = []
    padded_infos = collections.defaultdict(list)
    padded_lengths = []
    reward_masks = []

    for (o, a, r, i) in trajectories:
        # Determine the amount to pad, this holds true for obs, actions and rewards.
        num_to_pad = bucket_length + 1 - o.shape[0]
        padded_lengths.append(num_to_pad)

        # The mask uses the same dtype whether or not padding is needed, so
        # the stacked result does not depend on which trajectories were padded.
        reward_mask = onp.ones_like(r, dtype=np.int32)

        if num_to_pad == 0:
            # Already at bucket length: keep everything as is.
            padded_observations.append(o)
            padded_actions.append(a)
            padded_rewards.append(r)
            reward_masks.append(reward_mask)
            if i:
                for k, v in i.items():
                    padded_infos[k].append(v)
            continue

        # First pad observations. Only the leading (time) axis is padded,
        # at the high end, with no interior padding.
        padding_config = tuple([(0, num_to_pad, 0)] + [(0, 0, 0)] *
                               (o.ndim - 1))

        padding_value = get_padding_value(o.dtype)
        action_padding_value = get_padding_value(a.dtype)
        reward_padding_value = get_padding_value(r.dtype)

        padded_obs = lax.pad(o, padding_value, padding_config)
        padded_observations.append(padded_obs)

        # Now pad actions and rewards.
        padding_config = tuple([(0, num_to_pad, 0)] + [(0, 0, 0)] *
                               (a.ndim - 1))
        padded_action = lax.pad(a, action_padding_value, padding_config)
        padded_actions.append(padded_action)

        assert r.ndim == 1
        padding_config = ((0, num_to_pad, 0), )
        padded_reward = lax.pad(r, reward_padding_value, padding_config)
        padded_rewards.append(padded_reward)

        # Also create the mask to use later: zeros mark the padded steps.
        reward_masks.append(lax.pad(reward_mask, 0, padding_config))

        if i:
            for k, v in i.items():
                # Create a padding configuration for this value.
                padding_config = [(0, num_to_pad, 0)
                                  ] + [(0, 0, 0)] * (v.ndim - 1)
                padded_infos[k].append(lax.pad(v, 0.0, tuple(padding_config)))

    # Now stack these padded_infos if they exist.
    stacked_padded_infos = None
    if padded_infos:
        stacked_padded_infos = {
            k: np.stack(v)
            for k, v in padded_infos.items()
        }

    return padded_lengths, np.stack(reward_masks), np.stack(
        padded_observations), np.stack(padded_actions), np.stack(
            padded_rewards), stacked_padded_infos
Exemplo n.º 11
0
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):
        """Sets up models, optimizer, metrics and jitted step functions.

        Args:
          model: Callable invoked as `model(mode='train')` and
            `model(mode='eval')` to build the training and evaluation models.
          loss_fn: Loss layer; combined with the training model via
            `tl.Serial` and, when `metrics` is None, also reported as the
            'loss' metric.
          optimizer: Callable invoked as `optimizer(learning_rate=0.0)`.
          lr_schedule: Stored as `self._lr_schedule`.
          inputs: Either an Inputs instance or a function that returns one.
          output_dir: Passed to `reset()`. When falsy, checkpoint saving and
            summary writing are both disabled.
          random_seed: Passed to `_init_host_and_devices`.
          n_devices: Passed to `_init_host_and_devices`.
          checkpoints_at: Stored as `self._checkpoints_at`; defaults to [].
          should_save_checkpoints: Only honoured on the chief host and when
            `output_dir` is set.
          should_write_summaries: Only honoured when `output_dir` is set.
          metrics: Dict mapping metric names to metric layers. When None, a
            copy of `_DEFAULT_METRICS` extended with 'loss' is used.
          checkpoint_highest: Stored as `self._checkpoint_highest`.
          checkpoint_lowest: Stored as `self._checkpoint_lowest`.
        """

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        # Without an output directory there is nowhere to write anything.
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        if metrics is not None:
            self._metrics_dict = metrics
        else:
            # Work on a copy: inserting 'loss' straight into the module-level
            # defaults would leak this instance's loss into every other user
            # of `_DEFAULT_METRICS`.
            self._metrics_dict = dict(_DEFAULT_METRICS)
            self._metrics_dict['loss'] = loss_fn
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(
                inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        # `init_rng` is reserved for initialization; the remaining `rng` is
        # split into one key per device.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if math.backend_name() == 'jax':
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state)
        # Zero-argument factory bound to the initialization key.
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        # Metric names are kept in sorted order; the layers follow that order.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        # The metrics are initialized on the eval model's output signature.
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)