Example #1
def SumLearnedPick(positions):
    """Get a pair (vec, pos) and pick new pos."""
    succ_keys = positions[:-1, :]
    succ_values = positions[1:, :]
    subtract_1_keys = positions[1:, :]
    subtract_1_values = positions[:-1, :]
    l = int(positions.shape[0]) // 2
    add_keys = np.array([
        np.concatenate([positions[i, :], positions[j, :]]) for i in range(l)
        for j in range(l)
    ])
    add_values = np.array(
        [positions[i + j, :] for i in range(l) for j in range(l)])
    # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
    sub_keys = np.array([
        np.concatenate([positions[i, :], positions[j, :]]) for j in range(l)
        for i in range(l)
    ])
    sub_values = np.array(
        [positions[max(i - j, 0), :] for j in range(l) for i in range(l)])
    return tl.Serial(
        Dup2(), Dup2(), Dup2(), Dup2(),
        tl.Parallel(
            LearnedQP(),
            LearnedQP(keys=succ_keys, values=succ_values),
            LearnedQP(keys=subtract_1_keys, values=subtract_1_values),
            LearnedQP(keys=add_keys, values=add_values, binary=True),
            LearnedQP(keys=sub_keys, values=sub_values, binary=True),
        ), Softmax5Branches(n_branches=5))
 def test_fn_layer_example(self):
     layer = cb.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
     input_signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
     expected_shape = ((2, 7), (4, 7))
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
     inp = (np.array([2]), np.array([3]))
     x, xs = layer(inp)
     self.assertEqual(int(x), 5)
     self.assertEqual([int(y) for y in xs], [2, 3])
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
  """Query a table with a position vector."""
  if keys is None:
    return x
  k = np.array(keys)
  v = np.array(values)
  q = x
  if binary:
    q = np.concatenate([x, x], axis=-1)
  return tl.DotProductAttention(q, k, v, None, 0.0, None, None)
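For readers unfamiliar with tl.DotProductAttention, the table lookup above amounts to a softmax-weighted average of the value rows, weighted by query-key dot products. A minimal standalone numpy sketch of that math (not the Trax layer itself; the scaling by sqrt(d) is an assumption about the attention variant):

import numpy as np

def dot_product_lookup(q, keys, values):
    """Softmax(q . keys^T) applied to values -- the core of the lookup above."""
    scores = np.dot(q, keys.T) / np.sqrt(keys.shape[-1])  # scaled dot products
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.dot(weights, values)

# Toy position table: 4 positions, 2-dimensional encodings.
positions = np.eye(4, 2) + 0.1
result = dot_product_lookup(positions[0], positions[:-1], positions[1:])
print(result.shape)  # (2,)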
Example #4
 def test_scan_basic(self):
   @base.layer(n_in=2, n_out=2)
   def add(x, **unused_kwargs):
     res = x[0] + x[1]
     return res, res
   scan_layer = cb.Scan(add())  # pylint: disable=no-value-for-parameter
   input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
   expected_shape = ((3, 2, 7), (2, 7))
   output_shape = base.check_shape_agreement(scan_layer, input_signature)
   self.assertEqual(output_shape, expected_shape)
   inp = (np.array([1, 2, 3]), np.array(0))
   o, v = scan_layer(inp)
   self.assertEqual(int(v), 6)
   self.assertEqual([int(x) for x in o], [1, 3, 6])
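The expected outputs [1, 3, 6] and 6 are simply a running sum, which is what scanning the two-input add layer over the leading axis computes. A plain-Python sketch of that accumulation, independent of Trax:

import numpy as np

def scan_add(xs, init=0):
    """Running sum: mirrors what cb.Scan(add()) computes on the test input."""
    outputs, carry = [], init
    for x in xs:
        carry = carry + x
        outputs.append(carry)
    return np.array(outputs), carry

outputs, final = scan_add(np.array([1, 2, 3]))
print(outputs.tolist(), int(final))  # [1, 3, 6] 6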
def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding."""
  del kwargs
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  return pos
Example #6
def _multi_device_put(x, devices=None):
    """Memory efficient multi-device replication / broadcast in JAX.

  JAX uses a ShardedDeviceArray class that holds a list of device buffers
  on separate devices for use with pmap'd computations.  Sharded arrays
  are explicitly used to eliminate unnecessary inter-device transfer of
  memory buffers between use in pmap'd computations.  The JAX API currently
  does not have a multi-device 'put' function that copies a buffer onto
  N devices in a memory-efficient fashion, so we implement our own here.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to
      replicate onto.  Should match the list passed to any pmaps
      ingesting the replicated array.

  Returns:
    A ShardedDeviceArray with
    dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device_buffers on each local device.
  """
    # Convert _FilledConstants that don't have device_buffer, etc.
    if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
        x = np.array(x)
    # Calculate the abstract shape of the replicated array.
    if not devices:
        devices = jax.local_devices()
    n_devices = len(devices)
    x_aval = jax.xla.abstractify(x)
    broadcast_x_aval = jax.abstract_arrays.ShapedArray(
        (n_devices, ) + x_aval.shape, x_aval.dtype)
    # Create copies of the underlying device buffer for each local device.
    broadcast_buffers = [jax.device_put(x, dv).device_buffer for dv in devices]
    return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)
Example #7
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
  """Make a n+1 dim one-hot array from n dim int-categorical array."""
  arange_size = np.arange(size)
  if backend.get_name() == 'jax':
    # Work around a jax broadcasting issue.
    arange_size = jax.lax.tie_in(x, arange_size)
  return np.array(x[..., np.newaxis] == arange_size, dtype)
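The broadcast trick above compares each integer label against a full arange of class ids; a minimal numpy-only version (without the JAX tie_in workaround) behaves the same way:

import numpy as np

def one_hot_np(x, size, dtype=np.float32):
    """Pure-numpy one-hot: adds a trailing dimension of length `size`."""
    return np.array(x[..., np.newaxis] == np.arange(size), dtype)

labels = np.array([0, 2, 1])
print(one_hot_np(labels, 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]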
Example #8
 def test_from_file(self):
     params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
     filename = self.create_tempfile('params.npy').full_path
     with open(filename, 'wb') as f:
         np.save(f, params)
     initializer = initializers.InitializerFromFile(filename)
     input_shape = (3, 2)
     init_value = initializer(input_shape, random.get_prng(0))
     self.assertEqual('%s' % init_value, '%s' % params)
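The initializer in this test just reads back an array saved with np.save; the save/load round trip itself is plain numpy and can be sketched without the Trax initializer:

import os
import tempfile
import numpy as np

params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
filename = os.path.join(tempfile.mkdtemp(), 'params.npy')
with open(filename, 'wb') as f:
    np.save(f, params)
restored = np.load(filename)
print(np.array_equal(params, restored))  # True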
Example #9
def Fn(f, n_in=None, n_out=None):  # pylint: disable=invalid-name
    """Returns a layer with no weights that applies the function f.

  The function f can take and return any number of arguments, but it cannot
  have default arguments or keyword arguments. It can use numpy though, e.g.:

  A layer that takes 2 arguments and returns sum and concatenation on stack:

    Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))

  Sometimes determining the number of outputs automatically fails;
  in such cases, specify n_in and n_out explicitly.

  Args:
    f: the function to execute
    n_in: optional, number of inputs
    n_out: optional, number of outputs

  Returns:
    A layer executing the function f.
  """
    # Inspect the function f to restrict to no-defaults and no-kwargs functions.
    if six.PY2:
        argspec = inspect.getargspec(f)
        varkwargs = argspec.keywords
    else:
        argspec = inspect.getfullargspec(f)
        varkwargs = argspec.varkw
    # This layer cannot handle functions with kwargs or defaults.
    if argspec.defaults is not None:
        raise ValueError('function cannot have default arguments')
    if varkwargs:
        raise ValueError('function cannot have keyword arguments')

    # Determine n_in from function signature if not set.
    if n_in is None:
        if argspec.varargs is not None:
            raise ValueError('n_in is not set and f has variable args')
        n_in = len(argspec.args)
    # Try to determine n_out from function signature.
    if n_out is None:
        try:
            dummy_args = [np.array([[0.0]]) for _ in range(n_in)]
            res = f(*dummy_args)
            n_out = len(res) if isinstance(res, (list, tuple)) else 1
        except:
            raise ValueError('n_out is not set and could not be determined')

    # Create the layer.
    @layer(n_in=n_in, n_out=n_out)
    def F(xs, **unused_kwargs):  # pylint: disable=invalid-name
        if not isinstance(xs, (tuple, list)):
            xs = (xs, )
        return f(*xs)

    return F()  # pylint: disable=no-value-for-parameter
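The interesting part of Fn is how it infers n_in and n_out from the function itself. A standalone sketch of that inference using inspect.getfullargspec (Python 3 only; infer_n_in_n_out is a hypothetical helper name, not part of Trax):

import inspect
import numpy as np

def infer_n_in_n_out(f):
    """Mimics Fn's signature inspection: reject defaults/kwargs, count args."""
    spec = inspect.getfullargspec(f)
    if spec.defaults is not None or spec.varkw or spec.varargs:
        raise ValueError('plain positional arguments only')
    n_in = len(spec.args)
    dummy = [np.array([[0.0]]) for _ in range(n_in)]
    res = f(*dummy)
    n_out = len(res) if isinstance(res, (list, tuple)) else 1
    return n_in, n_out

print(infer_n_in_n_out(lambda x, y: (x + y, np.concatenate([x, y], axis=0))))
# (2, 2)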
Example #10
def multi_device_put(x, devices=None, reuse=True):
    """Memory efficient multi-device replication / broadcast in JAX.

  JAX uses a ShardedDeviceArray class that holds a list of device buffers
  on separate devices for use with pmap'd computations.  Sharded arrays
  are explicitly used to eliminate unnecessary inter-device transfer of
  memory buffers between use in pmap'd computations.  The JAX API currently
  does not have a multi-device 'put' function that copies a buffer onto
  N devices in a memory-efficient fashion, so we implement our own here.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to
      replicate onto.  Should match the list passed to any pmaps
      ingesting the replicated array.
    reuse: bool. If x is a DeviceArray whether to reuse its backing
      device_buffer in the resulting ShardedDeviceArray.  We do this
      by default to minimize a 2x overhead with large arrays.

  Returns:
    A ShardedDeviceArray with
    dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device_buffers on each local device.
  """
    # Convert _FilledConstants that don't have device_buffer, etc.
    if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
        x = np.array(x)
    # Calculate the abstract shape of the replicated array.
    if not devices:
        devices = jax.local_devices()
    n_devices = len(devices)
    x_aval = jax.xla.abstractify(x)
    broadcast_x_aval = jax.abstract_arrays.ShapedArray(
        (n_devices, ) + x_aval.shape, x_aval.dtype)
    # Create copies of the underlying device buffer for each local device.
    if reuse:
        # reuse the original device buffer for its device in the sharded
        # device array
        other_devices = [
            dv for dv in devices if dv != x.device_buffer.device()
        ]
        broadcast_buffers = ([
            x.device_buffer,
        ] + [
            jax.xla.xc.Buffer.from_pyval(x, device=dv) for dv in other_devices
        ])
    else:
        # make new copies of the buffer for every local device
        broadcast_buffers = [
            jax.xla.xc.Buffer.from_pyval(x, device=dv) for dv in devices
        ]
    return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)
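Conceptually the result is the input stacked n_devices times along a new leading axis, with each slice backed by a buffer on a different device. Ignoring device placement, the shape arithmetic can be illustrated with plain numpy; note also that newer JAX versions expose jax.device_put_replicated for this use case (treat that as an assumption about your JAX version):

import numpy as np

def replicate_shape_sketch(x, n_devices):
    """Shape-only illustration: (n_devices,) + x.shape, same dtype and values."""
    return np.broadcast_to(x, (n_devices,) + x.shape)

x = np.arange(6, dtype=np.float32).reshape(2, 3)
replicated = replicate_shape_sketch(x, n_devices=4)
print(replicated.shape, replicated.dtype)  # (4, 2, 3) float32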
Example #11
 def new_weights_and_state(self, input_signature):
     d_feature = input_signature.shape[-1]
     pe = onp.zeros((self._max_len, d_feature), dtype=onp.float32)
     position = onp.arange(0, self._max_len)[:, onp.newaxis]
     div_term = onp.exp(
         onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature))
     pe[:, 0::2] = onp.sin(position * div_term)
     pe[:, 1::2] = onp.cos(position * div_term)
     pe = pe[onp.newaxis, :, :]  # [1, self._max_len, d_feature]
     weights = np.array(
         pe)  # These are trainable parameters, initialized above.
     state = 0 if self._mode == 'predict' else base.EMPTY_STATE
     return weights, state
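The weight table built here is the standard sinusoidal positional encoding: sine on even feature columns, cosine on odd ones. A self-contained numpy sketch that reproduces the construction for a small table:

import numpy as np

def sinusoidal_pe(max_len, d_feature):
    """Standard sin/cos positional encoding of shape [1, max_len, d_feature]."""
    pe = np.zeros((max_len, d_feature), dtype=np.float32)
    position = np.arange(0, max_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe[np.newaxis, :, :]

print(sinusoidal_pe(max_len=8, d_feature=4).shape)  # (1, 8, 4)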
Example #12
    def __init__(self,
                 mode=None,
                 learn_epsilon=False,
                 init_epsilon=1e-6,
                 init_learnt_epsilon=1e-4):
        super(FilterResponseNorm, self).__init__()

        del mode

        # If we learn epsilon then epsilon = init_epsilon + |learnt_value|
        # where learnt_value is initialized to init_learnt_epsilon.
        # If learn_epsilon is false then epsilon is just init_epsilon.
        #
        # NOTE: I (afrozm) haven't been able to train with `learn_epsilon = True`.
        self._learn_epsilon = learn_epsilon

        assert init_epsilon > 0
        assert init_learnt_epsilon > 0

        self._init_epsilon = np.array(init_epsilon, dtype=np.float32)
        self._init_learnt_epsilon = np.array(init_learnt_epsilon,
                                             dtype=np.float32)
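For context, Filter Response Normalization divides each channel by the root mean square of its activations plus the epsilon configured here. A hedged numpy sketch of just the normalization step (the learned scale, bias, and thresholded activation from the FRN paper are omitted):

import numpy as np

def frn_normalize(x, epsilon=1e-6):
    """x: [batch, height, width, channels]; normalize by per-channel RMS."""
    nu2 = np.mean(np.square(x), axis=(1, 2), keepdims=True)  # mean of x^2 over H, W
    return x * (1.0 / np.sqrt(nu2 + epsilon))

x = np.random.randn(2, 4, 4, 3).astype(np.float32)
print(frn_normalize(x).shape)  # (2, 4, 4, 3)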
Example #13
 def update_model_state(self, key, value):
   """Updates model state based on nontrainable_params."""
   # Translate model state keys to nontrainable param names.
   if key in self._nontrainable_param_map:
     p_name = self._nontrainable_param_map[key]
   else:
     # If a key is not in mapping, it stays the same.
     p_name = key
   if p_name in self.nontrainable_params:
     if self._step == 0:
       log('Mapping model state key {} to nontrainable param {}.'
           ''.format(key, p_name))
       return self._for_n_devices(np.array(self.nontrainable_params[p_name]))
   return value
def PerformPositionOperations(pos, positions=None):
  """Gets pos and returns (q1, ..., q5)."""
  succ_keys = positions[:-1, :]
  succ_values = positions[1:, :]
  subtract_1_keys = positions[1:, :]
  subtract_1_values = positions[:-1, :]
  l = int(positions.shape[0]) // 2
  add_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for i in range(l) for j in range(l)])
  add_values = np.array([positions[i + j, :]
                         for i in range(l) for j in range(l)])
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  sub_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for j in range(l) for i in range(l)])
  sub_values = np.array([positions[max(i - j, 0), :]
                         for j in range(l) for i in range(l)])
  query_types = [
      QueryPositionKV(),
      QueryPositionKV(keys=succ_keys, values=succ_values),
      QueryPositionKV(keys=subtract_1_keys, values=subtract_1_values),
      QueryPositionKV(keys=add_keys, values=add_values, binary=True),
      QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)]
  return [qt @ pos for qt in query_types]  # pylint: disable=syntax-error
  def _train_step(self, next_train_batch):
    """Run one training step and update self._opt_state."""
    # Calculate the current optimizer parameters.
    # TODO(pkozakowski): Optimizer parameters get polluted with model state,
    # which doesn't break anything but is weird. Filter it out.
    opt_param_updates = layers.nested_map(
        lambda x: self._maybe_replicate(np.array(x)),
        self.nontrainable_params)
    opt_state = self._opt_state
    opt_state.opt_params.update(opt_param_updates)

    # Run the update.
    (weights, slots), self._model_state, self._rngs = self._jit_update_fn(
        self._step, opt_state, next_train_batch, self._model_state, self._rngs)
    self._model_state = self._map_to_state_dicts(self._state_dicts_update)
    self._opt_state = opt_state._replace(weights=weights, slots=slots)
    self._step += 1
  def __init__(self, learning_rate, **init_opt_params):
    """Initialize the optimizer.

    Takes the initial optimizer parameters as positional arguments. They are fed
    back to the optimizer in tree_update, in the same order. They can be changed
    between updates, e.g. for learning rate schedules.

    The constructor should be overridden in derived classes to give names to the
    optimizer parameters, so the gin configuration can set them.

    Args:
      learning_rate: The initial learning rate.
      **init_opt_params: Initial values of any additional optimizer parameters.
    """
    init_opt_params['learning_rate'] = learning_rate
    self._init_opt_params = {
        name: np.array(value) for (name, value) in init_opt_params.items()
    }
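A minimal sketch of how a derived optimizer would name its parameters through this constructor, so they become configurable; ToyOptimizer and ToyMomentum are hypothetical stand-ins, not the Trax classes:

import numpy as np

class ToyOptimizer:
    """Mimics the constructor above: store all opt params as numpy scalars."""
    def __init__(self, learning_rate, **init_opt_params):
        init_opt_params['learning_rate'] = learning_rate
        self._init_opt_params = {
            name: np.array(value) for (name, value) in init_opt_params.items()
        }

class ToyMomentum(ToyOptimizer):
    def __init__(self, learning_rate=0.01, mass=0.9):  # named, hence configurable
        super().__init__(learning_rate, mass=mass)

opt = ToyMomentum(learning_rate=0.1)
print(opt._init_opt_params)  # {'mass': array(0.9), 'learning_rate': array(0.1)}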
Example #17
 def new_params_and_state(self, input_shape, input_dtype, rng):
     del input_shape, input_dtype, rng
     params = ()
     state = {self._name: np.array(self._initial_rate)}
     return params, state
Example #18
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and value
      prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, "Building observations took %0.2f sec.",
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    assert policy_and_value_vocab_size is None, (
        "Serialized policies are not supported yet.")
    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_controls=len(control_configs),
        n_actions=len(action_multipliers),
        vocab_size=policy_and_value_vocab_size,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, "Building the policy network took %0.2f sec.",
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step)
    (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, "Policy checkpoint not found."
    (params, _) = opt_state
    logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()
    # ((log_probs, value_preds), state). We have no way to pass state to the next
    # step, but that should be fine.
    (log_probs, _) = (net(np.array([observations]),
                          params=params,
                          state=state,
                          rng=rng))
    logging.vlog(1, "Running the policy took %0.2f sec.",
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs) * observations.shape[0],
                               len(action_multipliers))
    action = utils.gumbel_sample(log_probs[0, -len(control_configs):, :] /
                                 temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
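The action is drawn with the Gumbel-max trick: add Gumbel noise to the temperature-scaled log-probabilities and take the argmax, which samples in proportion to the underlying distribution. A standalone numpy sketch of that sampling step (utils.gumbel_sample itself may differ in details):

import numpy as np

def gumbel_sample_np(log_probs, rng=np.random):
    """Gumbel-max sampling: argmax over log_probs + Gumbel(0, 1) noise."""
    u = rng.uniform(low=1e-9, high=1.0, size=log_probs.shape)
    gumbel_noise = -np.log(-np.log(u))
    return np.argmax(log_probs + gumbel_noise, axis=-1)

log_probs = np.log(np.array([[0.1, 0.2, 0.7]]))
print(gumbel_sample_np(log_probs))  # e.g. [2], sampled in proportion to the probs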
Example #19
 def new_weights_and_state(self, input_signature):
   del input_signature
   state = {self._name: np.array(self._initial_rate)}
   return base.EMPTY_WEIGHTS, state
Example #20
 def _decay_rate_pow(i, exponent=0.8):
     """Default Adafactor second-moment decay schedule."""
     t = np.array(i, np.float32) + 1.0
     return 1.0 - t**(-exponent)
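Plugging a few step counts into this schedule shows how the second-moment decay rate starts at 0 and approaches 1 as training progresses:

import numpy as np

def decay_rate_pow(i, exponent=0.8):
    t = np.array(i, np.float32) + 1.0
    return 1.0 - t**(-exponent)

for step in (0, 10, 100, 10000):
    print(step, float(decay_rate_pow(step)))
# 0 -> 0.0, 10 -> ~0.853, 100 -> ~0.975, 10000 -> ~0.999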
def PolicySchedule(
    history,
    observation_metrics=(
        ('train', 'metrics/accuracy'),
        ('train', 'metrics/loss'),
        ('eval', 'metrics/accuracy'),
        ('eval', 'metrics/loss'),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ('learning_rate', 1e-3, (1e-9, 10.0), False), ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
    """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and value
      prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {'name': float}, the
    step-dependent schedule for nontrainable parameters.
  """

    # Turn the history into observations for the policy. If we don't have any,
    # return the initial learning rate.
    start_time = time.time()
    observations = online_tune.history_to_observations(
        history, observation_metrics, observation_range,
        control_configs if include_controls_in_observation else None)
    logging.vlog(1, 'Building observations took %0.2f sec.',
                 time.time() - start_time)
    if observations.shape[0] == 0:
        controls = {
            name: start_value
            for (name, start_value, _, _) in control_configs
        }
        return lambda _: controls

    # Build the policy network and load its parameters.
    start_time = time.time()
    net = ppo.policy_and_value_net(
        n_controls=len(control_configs),
        n_actions=len(action_multipliers),
        vocab_size=policy_and_value_vocab_size,
        bottom_layers_fn=policy_and_value_model,
        two_towers=policy_and_value_two_towers,
    )
    logging.vlog(1, 'Building the policy network took %0.2f sec.',
                 time.time() - start_time)
    start_time = time.time()
    # (opt_state, state, epoch, opt_step, history)
    (opt_state, state, _, _, _) = ppo.maybe_restore_opt_state(policy_dir)
    assert opt_state is not None, 'Policy checkpoint not found.'
    (params, _, _) = opt_state
    logging.vlog(1, 'Restoring the policy parameters took %0.2f sec.',
                 time.time() - start_time)

    # Run the policy and sample an action.
    seed = random.randint(0, 2**31 - 1)
    rng = jax_random.get_prng(seed=seed)
    start_time = time.time()

    (low, high) = observation_range
    observation_space = gym.spaces.Box(shape=observations.shape[1:],
                                       low=low,
                                       high=high)
    action_space = gym.spaces.MultiDiscrete(nvec=(len(action_multipliers), ) *
                                            len(control_configs))
    n_timesteps = observations.shape[0]
    rewards_to_actions = ppo.init_rewards_to_actions(
        policy_and_value_vocab_size, observation_space, action_space,
        n_timesteps)
    # (log_probs, value_preds, state, rng)
    (log_probs, _, _, _) = ppo.run_policy(
        policy_and_value_net_apply=net,
        observations=np.array([observations]),
        lengths=np.array([n_timesteps]),
        weights=params,
        state=state,
        rng=rng,
        vocab_size=policy_and_value_vocab_size,
        observation_space=observation_space,
        action_space=action_space,
        rewards_to_actions=rewards_to_actions,
    )

    logging.vlog(1, 'Running the policy took %0.2f sec.',
                 time.time() - start_time)
    # Sample from the action distribution for the last timestep.
    assert log_probs.shape == (1, len(control_configs),
                               len(action_multipliers))
    action = utils.gumbel_sample(log_probs[0] / temperature)

    # Get new controls.
    controls = {
        # name: value
        control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
            control_config, control_action, history, action_multipliers)
        for (control_action, control_config) in zip(action, control_configs)
    }
    return lambda _: controls
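Each sampled action indexes into action_multipliers, and the corresponding control is multiplied by that factor and kept inside its configured range. A hedged sketch of that update; the real online_tune.update_control logic (e.g. the flip flag) may differ:

import numpy as np

def update_control_sketch(current_value, action, action_multipliers, low, high):
    """Multiply the control by the chosen multiplier, then clip to (low, high)."""
    new_value = current_value * action_multipliers[action]
    return float(np.clip(new_value, low, high))

multipliers = (1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5)
print(update_control_sketch(1e-3, action=4, action_multipliers=multipliers,
                            low=1e-9, high=10.0))  # 0.0015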
 def _state_dicts_update(self, state_dict):
   assert len(state_dict.keys()) == 1
   key = list(state_dict.keys())[0]
   value = np.array(state_dict[key])
   return {key: np.array(self.update_model_state(key, value))}