Example #1
def update_model_state(self, key, value):
    """Updates model state based on nontrainable_params."""
    # Translate model state keys to nontrainable param names.
    if key in self._nontrainable_param_map:
        param_name = self._nontrainable_param_map[key]
    else:
        # If a key is not in the mapping, it stays the same.
        param_name = key
    if param_name in self.nontrainable_params:
        if self._step == 0:
            log("Mapping model state key {} to nontrainable param {}.".format(
                key, param_name))
            return self._maybe_replicate(
                np.array(self.nontrainable_params[param_name]))
    return value
Example #2
def _chunked_positional_encoding_new_params(input_shape, rng, max_len=2048):  # pylint: disable=invalid-name
  """Helper: create positional encoding parameters."""
  del rng
  # Check if we are operating on chunked inputs by checking if the first
  # shape is a list/tuple of shapes (otherwise it's an int or numpy array).
  is_chunked = isinstance(input_shape[0], (list, tuple))
  feature_depth = input_shape[0][-1] if is_chunked else input_shape[-1]
  pe = onp.zeros((max_len, feature_depth), dtype=onp.float32)
  position = onp.arange(0, max_len)[:, onp.newaxis]
  div_term = onp.exp(
      onp.arange(0, feature_depth, 2) * -(onp.log(10000.0) / feature_depth))
  pe[:, 0::2] = onp.sin(position * div_term)
  pe[:, 1::2] = onp.cos(position * div_term)
  pe = pe[onp.newaxis, :, :]  # [1, max_len, feature_depth]
  return np.array(pe)  # These are trainable parameters, initialized as above.
Example #3
    def _train_step(self, next_train_batch):
        """Run one training step and update self._opt_state."""
        # Calculate the current learning rate.
        learning_rate = self._maybe_replicate(np.array(self.learning_rate))
        opt_state = self._opt_state
        opt_params = opt_state.opt_params
        opt_params = (learning_rate, ) + opt_params[1:]
        opt_state = opt_state._replace(opt_params=opt_params)

        # Run the update.
        (params, slots), self._model_state, self._rngs = self._jit_update_fn(
            self._step, opt_state, next_train_batch, self._model_state,
            self._rngs)
        self._opt_state = opt_state._replace(params=params, slots=slots)
        self._step += 1
Example #4
    def _train_step(self, next_train_batch):
        """Run one training step and update self._opt_state."""
        # Calculate the current optimizer parameters.
        opt_param_updates = layers.nested_map(
            self.optimizer_params,
            lambda x: self._maybe_replicate(np.array(x)))
        opt_state = self._opt_state
        opt_state.opt_params.update(opt_param_updates)

        # Run the update.
        (params, slots), self._model_state, self._rngs = self._jit_update_fn(
            self._step, opt_state, next_train_batch, self._model_state,
            self._rngs)
        self._model_state = self._map_to_state_dicts(self._state_dicts_update)
        self._opt_state = opt_state._replace(params=params, slots=slots)
        self._step += 1
Example #5
def _positional_encoding_new_params(  # pylint: disable=invalid-name
        input_shape,
        input_dtype,
        rng,
        max_len=2048):
    """Helper: create positional encoding parameters."""
    del input_dtype, rng
    d_feature = input_shape[-1]
    pe = onp.zeros((max_len, d_feature), dtype=onp.float32)
    position = onp.arange(0, max_len)[:, onp.newaxis]
    div_term = onp.exp(
        onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature))
    pe[:, 0::2] = onp.sin(position * div_term)
    pe[:, 1::2] = onp.cos(position * div_term)
    pe = pe[onp.newaxis, :, :]  # [1, max_len, d_feature]
    # These are trainable parameters, initialized as above.
    return np.array(pe)
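As a quick sanity check on the helper above (a minimal sketch, assuming the same np/onp aliases are in scope; the shapes are illustrative):
pe = _positional_encoding_new_params(
    input_shape=(1, 16, 8), input_dtype=onp.float32, rng=None, max_len=32)
# pe.shape == (1, 32, 8): one row of interleaved sin/cos features per position,
# broadcast over the batch dimension when added to the input embeddings.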
Example #6
  def _train_step(self, next_train_batch):
    """Run one training step and update self._opt_state."""
    # Calculate the current optimizer parameters.
    # TODO(pkozakowski): Optimizer parameters get polluted with model state,
    # which doesn't break anything but is weird. Filter it out.
    opt_param_updates = layers.nested_map(
        self.nontrainable_params, lambda x: self._maybe_replicate(np.array(x))
    )
    opt_state = self._opt_state
    opt_state.opt_params.update(opt_param_updates)

    # Run the update.
    (params, slots), self._model_state, self._rngs = self._jit_update_fn(
        self._step, opt_state, next_train_batch, self._model_state, self._rngs)
    self._model_state = self._map_to_state_dicts(self._state_dicts_update)
    self._opt_state = opt_state._replace(params=params, slots=slots)
    self._step += 1
Example #7
    def __init__(self, learning_rate, **init_opt_params):
        """Initialize the optimizer.

        Takes the initial optimizer parameters as keyword arguments. They are fed
        back to the optimizer in tree_update, keyed by name. They can be changed
        between updates, e.g. for learning rate schedules.

        The constructor should be overridden in derived classes to give names to
        the optimizer parameters, so the gin configuration can set them.

        Args:
          learning_rate: The initial learning rate.
          **init_opt_params: Initial values of any additional optimizer parameters.
        """
        init_opt_params["learning_rate"] = learning_rate
        self._init_opt_params = {
            name: np.array(value)
            for (name, value) in init_opt_params.items()
        }
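A sketch of the subclassing pattern the docstring describes; the class and parameter names below are illustrative, not part of the snippet:
class MomentumSGD(Optimizer):
    """Hypothetical optimizer that names its parameters for gin configuration."""

    def __init__(self, learning_rate=0.01, momentum=0.9):
        # Named keyword arguments end up in self._init_opt_params under their
        # names, so a gin config can set MomentumSGD.momentum directly.
        super(MomentumSGD, self).__init__(learning_rate, momentum=momentum)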
Example #8
def multi_device_put(x, devices=None, reuse=True):
    """Memory efficient multi-device replication in JAX.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to
      replicate onto.  Should match the list passed to any pmaps
      ingesting the replicated array.
    reuse: bool. If x is a DeviceArray whether to reuse its backing
      device_buffer in the resulting ShardedDeviceArray.

  Returns:
    A ShardedDeviceArray with dtype = x.dtype and shape =
    (n_devices,) + x.shape that's backed by replica
    device_buffers on each device.
  """
    # Convert _FilledConstants that don't have device_buffer, etc.
    if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
        x = np.array(x)
    if not devices:
        devices = jax.devices()
    n_devices = len(devices)
    x_aval = jax.xla.abstractify(x)
    broadcast_x_aval = jax.abstract_arrays.ShapedArray(
        (n_devices, ) + x_aval.shape, x_aval.dtype)
    if reuse:
        other_device_ordinals = [
            dv.id for dv in jax.devices() if dv != x.device_buffer.device()
        ]
        broadcast_buffers = ([
            x.device_buffer,
        ] + [
            jax.xla.xc.Buffer.from_pyval(x, device=i)
            for i in other_device_ordinals
        ])
    else:
        broadcast_buffers = [
            jax.xla.xc.Buffer.from_pyval(x, device=i) for i in range(n_devices)
        ]
    return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)
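A minimal usage sketch, assuming the same np alias (jax.numpy) and the older JAX internals the helper relies on:
batch_stat = np.zeros((8,), dtype=np.float32)
replicated = multi_device_put(batch_stat)
# replicated.shape == (len(jax.devices()), 8); each device holds one copy,
# matching what a pmapped function expects along its leading device axis.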
Example #9
def _state_dicts_update(self, state_dict):
    assert len(state_dict.keys()) == 1
    key = list(state_dict.keys())[0]
    value = np.array(state_dict[key])
    return {key: np.array(self.update_model_state(key, value))}
Example #10
def Relu(x, **unused_kwargs):
    return np.maximum(x, np.array(0, dtype=x.dtype))
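For example (with np as jax.numpy):
Relu(np.array([-2.0, 0.0, 3.0]))  # -> [0., 0., 3.]: negative entries are clamped to zero.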
Example #11
def one_hot(x, size, dtype=np.float32):
    """Make a n+1 dim one-hot array from n dim int-categorical array."""
    return np.array(x[..., np.newaxis] == np.arange(size), dtype)
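For example, a batch of three integer labels with size=3:
one_hot(np.array([0, 2, 1]), 3)
# -> [[1., 0., 0.],
#     [0., 0., 1.],
#     [0., 1., 0.]]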
Example #12
def _decay_rate_pow(i, exponent=0.8):
    """Default Adafactor second-moment decay schedule."""
    t = np.array(i, np.float32) + 1.0
    return 1.0 - t**(-exponent)
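For intuition, the decay rate starts at 0 and approaches 1 as the step count grows; calling the helper directly:
_decay_rate_pow(0)   # 1 - 1**(-0.8)    = 0.0
_decay_rate_pow(99)  # 1 - 100**(-0.8) ≈ 0.975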
Example #13
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_lr_in_observation=False,
    observation_range=(0.0, 5.0),
    start_lr=0.001,
    max_lr=10.0,
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_dir=gin.REQUIRED,
):
    """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History object.
    include_lr_in_observation: bool, whether to include the learning rate in
      observations.
    observation_range: tuple (low, high), range to clip the observation to.
    start_lr: starting learning rate.
    max_lr: maximum value to clip the learning rate to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and value
      prediction is computed by separate model towers.
    policy_dir: directory with the policy checkpoint.

  Returns:
    a function learning_rate(step): float -> float, the step-dependent lr.
  """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        include_lr_in_observation)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        return lambda _: start_lr

    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_actions=len(action_multipliers),
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # ((log_probs, value_preds), state). We have no way to pass state to the next
    # step, but that should be fine.
    ((log_probs, _), _) = net(np.array([observations]), params, state, rng=rng)
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    action = utils.gumbel_sample(log_probs[0, -1, :])

    # Get a new learning rate.
    new_lr = online_tune.new_learning_rate(action, history, action_multipliers,
                                           max_lr)
    return lambda _: new_lr
Example #14
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and value
      prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    assert policy_and_value_vocab_size is None, (
        "Serialized policies are not supported yet.")
    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_controls=len(control_configs),
        n_actions=len(action_multipliers),
        vocab_size=policy_and_value_vocab_size,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # ((log_probs, value_preds), state). We have no way to pass state to the next
    # step, but that should be fine.
    (log_probs, _) = (net(np.array([observations]),
                          params=params,
                          state=state,
                          rng=rng))
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs) * observations.shape[0],
                               len(action_multipliers))
    action = utils.gumbel_sample(log_probs[0, -len(control_configs):, :] /
                                 temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
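Note that the returned callable ignores its step argument: the policy is queried once per call to PolicySchedule, and the resulting controls are returned as a constant schedule until it is rebuilt. A hypothetical usage sketch (the checkpoint directory is illustrative, and a restored policy checkpoint is assumed to exist there):
schedule = PolicySchedule(history, policy_dir="/tmp/online_tune_policy")
controls = schedule(0)  # e.g. {"learning_rate": 1.25e-3}; the step argument is unused.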
Example #15
def new_params_and_state(self, input_shape, input_dtype, rng):
    del input_shape, input_dtype, rng
    params = ()
    state = {self._name: np.array(self._initial_rate)}
    return params, state
Example #16
def new_parameters(self, input_shape, input_dtype, rng):
    """Initialize dropout parameters and state."""
    del input_shape, input_dtype, rng
    return (), {self._name: np.array(self._initial_rate)}