def SumLearnedPick(positions):
  """Get a pair (vec, pos) and pick new pos."""
  succ_keys = positions[:-1, :]
  succ_values = positions[1:, :]
  subtract_1_keys = positions[1:, :]
  subtract_1_values = positions[:-1, :]
  l = int(positions.shape[0]) // 2
  add_keys = np.array([
      np.concatenate([positions[i, :], positions[j, :]])
      for i in range(l) for j in range(l)
  ])
  add_values = np.array(
      [positions[i + j, :] for i in range(l) for j in range(l)])
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  sub_keys = np.array([
      np.concatenate([positions[i, :], positions[j, :]])
      for j in range(l) for i in range(l)
  ])
  sub_values = np.array(
      [positions[max(i - j, 0), :] for j in range(l) for i in range(l)])
  return tl.Serial(
      Dup2(), Dup2(), Dup2(), Dup2(),
      tl.Parallel(
          LearnedQP(),
          LearnedQP(keys=succ_keys, values=succ_values),
          LearnedQP(keys=subtract_1_keys, values=subtract_1_values),
          LearnedQP(keys=add_keys, values=add_values, binary=True),
          LearnedQP(keys=sub_keys, values=sub_values, binary=True),
      ),
      Softmax5Branches(n_branches=5))
def test_fn_layer_example(self):
  layer = cb.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
  input_signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
  expected_shape = ((2, 7), (4, 7))
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
  inp = (np.array([2]), np.array([3]))
  x, xs = layer(inp)
  self.assertEqual(int(x), 5)
  self.assertEqual([int(y) for y in xs], [2, 3])
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
  """Query a table with a position vector."""
  if keys is None:
    return x
  k = np.array(keys)
  v = np.array(values)
  q = x
  if binary:
    q = np.concatenate([x, x], axis=-1)
  return tl.DotProductAttention(q, k, v, None, 0.0, None, None)
def test_scan_basic(self):
  @base.layer(n_in=2, n_out=2)
  def add(x, **unused_kwargs):
    res = x[0] + x[1]
    return res, res

  scan_layer = cb.Scan(add())  # pylint: disable=no-value-for-parameter
  input_signature = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
  expected_shape = ((3, 2, 7), (2, 7))
  output_shape = base.check_shape_agreement(scan_layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
  inp = (np.array([1, 2, 3]), np.array(0))
  o, v = scan_layer(inp)
  self.assertEqual(int(v), 6)
  self.assertEqual([int(x) for x in o], [1, 3, 6])
def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding."""
  del kwargs
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  pos += np.zeros((np.shape(x)[0], 1, 1))  # Broadcast on batch.
  return pos
def _multi_device_put(x, devices=None):
  """Memory efficient multi-device replication / broadcast in JAX.

  JAX uses a ShardedDeviceArray class that holds a list of device buffers
  on separate devices for use with pmap'd computations.  Sharded arrays
  are explicitly used to eliminate unnecessary inter-device transfer of
  memory buffers between use in pmap'd computations.  The JAX API currently
  does not have a multi-device 'put' function that copies a buffer onto N
  devices in a memory-efficient fashion, so we implement our own here.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to replicate
      onto.  Should match the list passed to any pmaps ingesting the
      replicated array.

  Returns:
    A ShardedDeviceArray with
    dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device_buffers on each local device.
  """
  # Convert _FilledConstants that don't have device_buffer, etc.
  if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
    x = np.array(x)
  # Calculate the abstract shape of the replicated array.
  if not devices:
    devices = jax.local_devices()
  n_devices = len(devices)
  x_aval = jax.xla.abstractify(x)
  broadcast_x_aval = jax.abstract_arrays.ShapedArray(
      (n_devices,) + x_aval.shape, x_aval.dtype)
  # Create copies of the underlying device buffer for each local device.
  broadcast_buffers = [jax.device_put(x, dv).device_buffer for dv in devices]
  return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
  """Make an (n+1)-dim one-hot array from an n-dim int-categorical array."""
  arange_size = np.arange(size)
  if backend.get_name() == 'jax':
    # Work around a jax broadcasting issue.
    arange_size = jax.lax.tie_in(x, arange_size)
  return np.array(x[..., np.newaxis] == arange_size, dtype)
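# Usage sketch for one_hot -- a minimal illustration, assuming `np` behaves
# like plain NumPy here (in this codebase `np` is the backend numpy, which
# matches for this example). The comparison x[..., np.newaxis] == arange_size
# broadcasts each category against the class indices, and the cast to `dtype`
# turns the resulting boolean mask into a one-hot float array.
#
#   one_hot(np.array([1, 0, 2]), size=3)
#   # -> [[0., 1., 0.],
#   #     [1., 0., 0.],
#   #     [0., 0., 1.]]   (float32, shape (3, 3))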
def test_from_file(self):
  params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
  filename = self.create_tempfile('params.npy').full_path
  with open(filename, 'wb') as f:
    np.save(f, params)
  initializer = initializers.InitializerFromFile(filename)
  input_shape = (3, 2)
  init_value = initializer(input_shape, random.get_prng(0))
  self.assertEqual('%s' % init_value, '%s' % params)
def Fn(f, n_in=None, n_out=None):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function f.

  The function f can take and return any number of arguments, but it cannot
  have default arguments or keyword arguments. It can use numpy though, e.g.:

  A layer that takes 2 arguments and returns sum and concatenation on stack:

    Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))

  Sometimes determining the number of outputs automatically fails; in such
  cases, specify n_in and n_out.

  Args:
    f: the function to execute
    n_in: optional, number of inputs
    n_out: optional, number of outputs

  Returns:
    A layer executing the function f.
  """
  # Inspect the function f to restrict to no-defaults and no-kwargs functions.
  if six.PY2:
    argspec = inspect.getargspec(f)
    varkwargs = argspec.keywords
  else:
    argspec = inspect.getfullargspec(f)
    varkwargs = argspec.varkw
  # This layer cannot handle functions with kwargs or defaults.
  if argspec.defaults is not None:
    raise ValueError('function cannot have default arguments')
  if varkwargs:
    raise ValueError('function cannot have keyword arguments')
  # Determine n_in from function signature if not set.
  if n_in is None:
    if argspec.varargs is not None:
      raise ValueError('n_in is not set and f has variable args')
    n_in = len(argspec.args)
  # Try to determine n_out from function signature.
  if n_out is None:
    try:
      dummy_args = [np.array([[0.0]]) for _ in range(n_in)]
      res = f(*dummy_args)
      n_out = len(res) if isinstance(res, (list, tuple)) else 1
    except:
      raise ValueError('n_out is not set and could not be determined')
  # Create the layer.
  @layer(n_in=n_in, n_out=n_out)
  def F(xs, **unused_kwargs):  # pylint: disable=invalid-name
    if not isinstance(xs, (tuple, list)):
      xs = (xs,)
    return f(*xs)
  return F()  # pylint: disable=no-value-for-parameter
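# Usage sketch for Fn -- mirrors the docstring example and the
# test_fn_layer_example test above; a layer that sums and concatenates two
# stack inputs:
#
#   layer = Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
#   s, c = layer((np.array([2]), np.array([3])))
#   # s == [5], c == [2, 3]; n_in=2 and n_out=2 are inferred by probing f
#   # with dummy arrays, so they don't need to be passed explicitly.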
def multi_device_put(x, devices=None, reuse=True):
  """Memory efficient multi-device replication / broadcast in JAX.

  JAX uses a ShardedDeviceArray class that holds a list of device buffers
  on separate devices for use with pmap'd computations.  Sharded arrays
  are explicitly used to eliminate unnecessary inter-device transfer of
  memory buffers between use in pmap'd computations.  The JAX API currently
  does not have a multi-device 'put' function that copies a buffer onto N
  devices in a memory-efficient fashion, so we implement our own here.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof of devices to replicate
      onto.  Should match the list passed to any pmaps ingesting the
      replicated array.
    reuse: bool. If x is a DeviceArray, whether to reuse its backing
      device_buffer in the resulting ShardedDeviceArray. We do this by
      default to minimize a 2x overhead with large arrays.

  Returns:
    A ShardedDeviceArray with
    dtype = x.dtype and shape = (n_devices,) + x.shape
    that's backed by replicated device_buffers on each local device.
  """
  # Convert _FilledConstants that don't have device_buffer, etc.
  if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
    x = np.array(x)
  # Calculate the abstract shape of the replicated array.
  if not devices:
    devices = jax.local_devices()
  n_devices = len(devices)
  x_aval = jax.xla.abstractify(x)
  broadcast_x_aval = jax.abstract_arrays.ShapedArray(
      (n_devices,) + x_aval.shape, x_aval.dtype)
  # Create copies of the underlying device buffer for each local device.
  if reuse:
    # Reuse the original device buffer for its own device in the sharded
    # device array; copy onto the remaining devices.
    other_devices = [dv for dv in devices if dv != x.device_buffer.device()]
    broadcast_buffers = (
        [x.device_buffer] +
        [jax.xla.xc.Buffer.from_pyval(x, device=dv) for dv in other_devices])
  else:
    # Make new copies of the buffer for every local device.
    broadcast_buffers = [
        jax.xla.xc.Buffer.from_pyval(x, device=dv) for dv in devices
    ]
  return jax.pxla.ShardedDeviceArray(broadcast_x_aval, broadcast_buffers)
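# Usage sketch for multi_device_put -- hedged, since the concrete result
# depends on the local device set at runtime:
#
#   x = np.ones((4, 4))
#   sharded = multi_device_put(x)    # replicate onto all local devices
#   # sharded.shape == (jax.local_device_count(), 4, 4); each device holds
#   # its own buffer, so a subsequent pmap can consume the array without an
#   # extra host-to-device transfer.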
def new_weights_and_state(self, input_signature):
  d_feature = input_signature.shape[-1]
  pe = onp.zeros((self._max_len, d_feature), dtype=onp.float32)
  position = onp.arange(0, self._max_len)[:, onp.newaxis]
  div_term = onp.exp(
      onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature))
  pe[:, 0::2] = onp.sin(position * div_term)
  pe[:, 1::2] = onp.cos(position * div_term)
  pe = pe[onp.newaxis, :, :]  # [1, self._max_len, d_feature]
  weights = np.array(pe)  # These are trainable parameters, initialized above.
  state = 0 if self._mode == 'predict' else base.EMPTY_STATE
  return weights, state
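# Note on the table built above: it is the standard sinusoidal positional
# encoding. For position `pos` and feature pair index `i`,
#   pe[pos, 2i]   = sin(pos / 10000^(2i / d_feature))
#   pe[pos, 2i+1] = cos(pos / 10000^(2i / d_feature))
# computed via exp(-log(10000) * 2i / d_feature) for numerical convenience,
# then given a leading axis of size 1 so it broadcasts over the batch.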
def __init__(self, mode=None, learn_epsilon=False, init_epsilon=1e-6,
             init_learnt_epsilon=1e-4):
  super(FilterResponseNorm, self).__init__()
  del mode

  # If we learn epsilon then epsilon = init_epsilon + |learnt_value|,
  # where learnt_value is initialized to init_learnt_epsilon.
  # If learn_epsilon is False then epsilon is just init_epsilon.
  #
  # NOTE: I (afrozm) haven't been able to train with `learn_epsilon = True`.
  self._learn_epsilon = learn_epsilon

  assert init_epsilon > 0
  assert init_learnt_epsilon > 0

  self._init_epsilon = np.array(init_epsilon, dtype=np.float32)
  self._init_learnt_epsilon = np.array(init_learnt_epsilon, dtype=np.float32)
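# Sketch of how epsilon is meant to be resolved at apply time, restating the
# note in the constructor as code (a hypothetical helper, not part of the
# class as shown here):
#
#   def _epsilon(self, learnt_value):
#     if self._learn_epsilon:
#       return self._init_epsilon + np.abs(learnt_value)
#     return self._init_epsilon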
def update_model_state(self, key, value):
  """Updates model state based on nontrainable_params."""
  # Translate model state keys to nontrainable param names.
  if key in self._nontrainable_param_map:
    p_name = self._nontrainable_param_map[key]
  else:
    # If a key is not in the mapping, it stays the same.
    p_name = key
  if p_name in self.nontrainable_params:
    if self._step == 0:
      log('Mapping model state key {} to nontrainable param {}.'
          ''.format(key, p_name))
    return self._for_n_devices(np.array(self.nontrainable_params[p_name]))
  return value
def PerformPositionOperations(pos, positions=None):
  """Gets pos and returns (q1, ..., q5)."""
  succ_keys = positions[:-1, :]
  succ_values = positions[1:, :]
  subtract_1_keys = positions[1:, :]
  subtract_1_values = positions[:-1, :]
  l = int(positions.shape[0]) // 2
  add_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for i in range(l) for j in range(l)])
  add_values = np.array([positions[i + j, :]
                         for i in range(l) for j in range(l)])
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  sub_keys = np.array([np.concatenate([positions[i, :], positions[j, :]])
                       for j in range(l) for i in range(l)])
  sub_values = np.array([positions[max(i - j, 0), :]
                         for j in range(l) for i in range(l)])
  query_types = [
      QueryPositionKV(),
      QueryPositionKV(keys=succ_keys, values=succ_values),
      QueryPositionKV(keys=subtract_1_keys, values=subtract_1_values),
      QueryPositionKV(keys=add_keys, values=add_values, binary=True),
      QueryPositionKV(keys=sub_keys, values=sub_values, binary=True)]
  return [qt @ pos for qt in query_types]  # pylint: disable=syntax-error
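# Note on the tables above: `add_keys`/`sub_keys` pair up every two position
# vectors (concatenated, hence the binary=True queries of twice the width),
# and the matching values are the position vectors at index i + j or
# max(i - j, 0). Attention over these tables therefore acts as a soft lookup
# table implementing position addition and subtraction, while the succ/
# subtract_1 tables shift positions by one step forward or back.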
def _train_step(self, next_train_batch):
  """Run one training step and update self._opt_state."""
  # Calculate the current optimizer parameters.
  # TODO(pkozakowski): Optimizer parameters get polluted with model state,
  # which doesn't break anything but is weird. Filter it out.
  opt_param_updates = layers.nested_map(
      lambda x: self._maybe_replicate(np.array(x)),
      self.nontrainable_params)
  opt_state = self._opt_state
  opt_state.opt_params.update(opt_param_updates)

  # Run the update.
  (weights, slots), self._model_state, self._rngs = self._jit_update_fn(
      self._step, opt_state, next_train_batch, self._model_state, self._rngs)
  self._model_state = self._map_to_state_dicts(self._state_dicts_update)
  self._opt_state = opt_state._replace(weights=weights, slots=slots)
  self._step += 1
def __init__(self, learning_rate, **init_opt_params):
  """Initialize the optimizer.

  Takes the initial optimizer parameters as positional arguments. They are
  fed back to the optimizer in tree_update, in the same order. They can be
  changed between updates, e.g. for learning rate schedules.

  The constructor should be overridden in derived classes to give names to
  the optimizer parameters, so the gin configuration can set them.

  Args:
    learning_rate: The initial learning rate.
    **init_opt_params: Initial values of any additional optimizer parameters.
  """
  init_opt_params['learning_rate'] = learning_rate
  self._init_opt_params = {
      name: np.array(value) for (name, value) in init_opt_params.items()
  }
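# Sketch of a derived optimizer that names its parameters so gin can set
# them -- a hypothetical Momentum-style subclass, assuming the surrounding
# Optimizer base class; not taken verbatim from the library:
#
#   class Momentum(Optimizer):
#
#     def __init__(self, learning_rate, mass=0.9):
#       super(Momentum, self).__init__(learning_rate, mass=mass)
#
# The named keyword `mass` ends up in self._init_opt_params alongside
# 'learning_rate' and can be overridden from gin or changed between updates.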
def new_params_and_state(self, input_shape, input_dtype, rng):
  del input_shape, input_dtype, rng
  params = ()
  state = {self._name: np.array(self._initial_rate)}
  return params, state
def PolicySchedule(
    history,
    observation_metrics=(
        ("train", "metrics/accuracy"),
        ("train", "metrics/loss"),
        ("eval", "metrics/accuracy"),
        ("eval", "metrics/loss"),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ("learning_rate", 1e-3, (1e-9, 10.0), False),
    ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
  """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {"name": float}, the
    step-dependent schedule for nontrainable parameters.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial learning rate.
  start_time = time.time()
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range,
      control_configs if include_controls_in_observation else None)
  logging.vlog(1, "Building observations took %0.2f sec.",
               time.time() - start_time)
  if observations.shape[0] == 0:
    controls = {
        name: start_value
        for (name, start_value, _, _) in control_configs
    }
    return lambda _: controls

  assert policy_and_value_vocab_size is None, (
      "Serialized policies are not supported yet.")

  # Build the policy network and load its parameters.
  start_time = time.time()
  net = ppo.policy_and_value_net(
      n_controls=len(control_configs),
      n_actions=len(action_multipliers),
      vocab_size=policy_and_value_vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(1, "Building the policy network took %0.2f sec.",
               time.time() - start_time)
  start_time = time.time()
  # (opt_state, state, epoch, opt_step)
  (opt_state, state, _, _) = ppo.maybe_restore_opt_state(policy_dir)
  assert opt_state is not None, "Policy checkpoint not found."
  (params, _) = opt_state
  logging.vlog(1, "Restoring the policy parameters took %0.2f sec.",
               time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  # ((log_probs, value_preds), state). We have no way to pass state to the
  # next step, but that should be fine.
  (log_probs, _) = (
      net(np.array([observations]), params=params, state=state, rng=rng))
  logging.vlog(1, "Running the policy took %0.2f sec.",
               time.time() - start_time)

  # Sample from the action distribution for the last timestep.
  assert log_probs.shape == (
      1, len(control_configs) * observations.shape[0],
      len(action_multipliers))
  action = utils.gumbel_sample(
      log_probs[0, -len(control_configs):, :] / temperature)

  # Get new controls.
  controls = {
      # name: value
      control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
          control_config, control_action, history, action_multipliers)
      for (control_action, control_config) in zip(action, control_configs)
  }
  return lambda _: controls
def new_weights_and_state(self, input_signature):
  del input_signature
  state = {self._name: np.array(self._initial_rate)}
  return base.EMPTY_WEIGHTS, state
def _decay_rate_pow(i, exponent=0.8):
  """Default Adafactor second-moment decay schedule."""
  t = np.array(i, np.float32) + 1.0
  return 1.0 - t**(-exponent)
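# Quick numeric check of the schedule above (plain arithmetic, for intuition):
#   step 0:   1 - 1**(-0.8)   = 0.0
#   step 9:   1 - 10**(-0.8)  ~ 0.842
#   step 99:  1 - 100**(-0.8) ~ 0.975
# so the second-moment decay rate approaches 1 as training progresses.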
def PolicySchedule(
    history,
    observation_metrics=(
        ('train', 'metrics/accuracy'),
        ('train', 'metrics/loss'),
        ('eval', 'metrics/accuracy'),
        ('eval', 'metrics/loss'),
    ),
    include_controls_in_observation=False,
    control_configs=(
        # (name, start, (low, high), flip)
        ('learning_rate', 1e-3, (1e-9, 10.0), False),
    ),
    observation_range=(0.0, 10.0),
    action_multipliers=(1.0 / 1.5, 1.0 / 1.25, 1.0, 1.25, 1.5),
    policy_and_value_model=trax_models.FrameStackMLP,
    policy_and_value_two_towers=False,
    policy_and_value_vocab_size=None,
    policy_dir=gin.REQUIRED,
    temperature=1.0,
):
  """Learning rate schedule controlled by a learned policy.

  Args:
    history: the history of training and evaluation (History object).
    observation_metrics: list of pairs (mode, metric), as in the History
      object.
    include_controls_in_observation: bool, whether to include the controls in
      observations.
    control_configs: control configs, see trax.rl.envs.OnlineTuneEnv.
    observation_range: tuple (low, high), range to clip the metrics to.
    action_multipliers: sequence of LR multipliers that policy actions
      correspond to.
    policy_and_value_model: Trax model to use as the policy.
    policy_and_value_two_towers: bool, whether the action distribution and
      value prediction is computed by separate model towers.
    policy_and_value_vocab_size: vocabulary size of a policy and value network
      operating on serialized representation. If None, use raw continuous
      representation.
    policy_dir: directory with the policy checkpoint.
    temperature: temperature for sampling from the policy.

  Returns:
    a function nontrainable_params(step): float -> {'name': float}, the
    step-dependent schedule for nontrainable parameters.
  """
  # Turn the history into observations for the policy. If we don't have any,
  # return the initial learning rate.
  start_time = time.time()
  observations = online_tune.history_to_observations(
      history, observation_metrics, observation_range,
      control_configs if include_controls_in_observation else None)
  logging.vlog(1, 'Building observations took %0.2f sec.',
               time.time() - start_time)
  if observations.shape[0] == 0:
    controls = {
        name: start_value
        for (name, start_value, _, _) in control_configs
    }
    return lambda _: controls

  # Build the policy network and load its parameters.
  start_time = time.time()
  net = ppo.policy_and_value_net(
      n_controls=len(control_configs),
      n_actions=len(action_multipliers),
      vocab_size=policy_and_value_vocab_size,
      bottom_layers_fn=policy_and_value_model,
      two_towers=policy_and_value_two_towers,
  )
  logging.vlog(1, 'Building the policy network took %0.2f sec.',
               time.time() - start_time)
  start_time = time.time()
  # (opt_state, state, epoch, opt_step, history)
  (opt_state, state, _, _, _) = ppo.maybe_restore_opt_state(policy_dir)
  assert opt_state is not None, 'Policy checkpoint not found.'
  (params, _, _) = opt_state
  logging.vlog(1, 'Restoring the policy parameters took %0.2f sec.',
               time.time() - start_time)

  # Run the policy and sample an action.
  seed = random.randint(0, 2**31 - 1)
  rng = jax_random.get_prng(seed=seed)
  start_time = time.time()
  (low, high) = observation_range
  observation_space = gym.spaces.Box(
      shape=observations.shape[1:], low=low, high=high)
  action_space = gym.spaces.MultiDiscrete(
      nvec=(len(action_multipliers),) * len(control_configs))
  n_timesteps = observations.shape[0]
  rewards_to_actions = ppo.init_rewards_to_actions(
      policy_and_value_vocab_size, observation_space, action_space,
      n_timesteps)
  # (log_probs, value_preds, state, rng)
  (log_probs, _, _, _) = ppo.run_policy(
      policy_and_value_net_apply=net,
      observations=np.array([observations]),
      lengths=np.array([n_timesteps]),
      weights=params,
      state=state,
      rng=rng,
      vocab_size=policy_and_value_vocab_size,
      observation_space=observation_space,
      action_space=action_space,
      rewards_to_actions=rewards_to_actions,
  )
  logging.vlog(1, 'Running the policy took %0.2f sec.',
               time.time() - start_time)

  # Sample from the action distribution for the last timestep.
  assert log_probs.shape == (
      1, len(control_configs), len(action_multipliers))
  action = utils.gumbel_sample(log_probs[0] / temperature)

  # Get new controls.
  controls = {
      # name: value
      control_config[0]: online_tune.update_control(  # pylint: disable=g-complex-comprehension
          control_config, control_action, history, action_multipliers)
      for (control_action, control_config) in zip(action, control_configs)
  }
  return lambda _: controls
def _state_dicts_update(self, state_dict):
  assert len(state_dict.keys()) == 1
  key = list(state_dict.keys())[0]
  value = np.array(state_dict[key])
  return {key: np.array(self.update_model_state(key, value))}