def PerformPositionOperations(pos, positions=None):
  """Gets pos and returns (q1, ..., q5).

  Each q is `qt @ pos` for one of five `QueryPositionKV` tables built from
  the rows of `positions` (with `half = positions.shape[0] // 2`):

    q1: no keys (identity lookup).
    q2: successor,   key row k     -> value row k + 1.
    q3: predecessor, key row k + 1 -> value row k.
    q4: addition,    key concat(row i, row j) -> value row i + j.
    q5: subtraction, key concat(row i, row j) -> value row max(i - j, 0).

  For q4 and q5, i and j both range over range(half).

  Args:
    pos: position value(s) that every table is queried with.
    positions: 2-D table of position vectors, one per row. It defaults to
      None but is required: `positions.shape` is read unconditionally.

  Returns:
    The list [q1, q2, q3, q4, q5].
  """
  half = int(positions.shape[0]) // 2

  # The two binary tables enumerate their (i, j) pairs in different orders:
  # addition varies j fastest, subtraction varies i fastest.
  add_pairs = [(i, j) for i in range(half) for j in range(half)]
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  # (where l is `half`).
  sub_pairs = [(i, j) for j in range(half) for i in range(half)]

  def _concat_rows(pairs):
    """Key rows: the two position vectors of each pair, concatenated."""
    return np.array(
        [np.concatenate([positions[i, :], positions[j, :]]) for i, j in pairs])

  add_values = np.array([positions[i + j, :] for i, j in add_pairs])
  sub_values = np.array([positions[max(i - j, 0), :] for i, j in sub_pairs])

  query_types = [
      QueryPositionKV(),
      QueryPositionKV(keys=positions[:-1, :], values=positions[1:, :]),
      QueryPositionKV(keys=positions[1:, :], values=positions[:-1, :]),
      QueryPositionKV(
          keys=_concat_rows(add_pairs), values=add_values, binary=True),
      QueryPositionKV(
          keys=_concat_rows(sub_pairs), values=sub_values, binary=True),
  ]
  return [qt @ pos for qt in query_types]  # pylint: disable=syntax-error
def PPOJointLoss(x, **unused_kwargs):
  """Definition of the Proximal Policy Optimization loss.

  This is a closure: it reads `self._policy_dist`, `self._value_loss_coeff`,
  `self._epsilon` and `self._entropy_coeff` from the enclosing scope.

  Args:
    x: tuple (dist_inputs, values, returns, actions, old_log_probs, mask);
      `mask` is currently ignored.
    **unused_kwargs: ignored.

  Returns:
    -mean(clipped PPO objective) + value loss - entropy term.
  """
  dist_inputs, values, returns, actions, old_log_probs, mask = x
  del mask  # TODO(lukaszkaiser): make PPO work with Transformer
  raw_new_log_probs = self._policy_dist.log_prob(dist_inputs, actions)

  advantages = returns - values
  # NOTE(review): the value loss is a sum while the policy term below is a
  # mean, and `advantages` is not wrapped in a stop-gradient, so the policy
  # term also backpropagates into `values`. Confirm both are intended.
  l2_value_loss = jnp.sum(advantages**2) * self._value_loss_coeff

  # Both log-prob arrays carry an unwanted trailing dimension; drop it.
  old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1), dtype=jnp.float32)
  new_log_probs = jnp.array(raw_new_log_probs.squeeze(axis=-1))

  # new_probs / old_probs, computed as the exponential of the log difference.
  probs_ratio = jnp.exp(new_log_probs - old_log_probs)
  low, high = 1 - self._epsilon, 1 + self._epsilon
  ppo_objective = jnp.minimum(probs_ratio * advantages,
                              jnp.clip(probs_ratio, low, high) * advantages)

  entropy_loss = (
      self._policy_dist.entropy(new_log_probs) * self._entropy_coeff)
  return -ppo_objective.mean() + l2_value_loss - entropy_loss
def test_concatenate(self):
  """Concatenate joins along axis 0, axis 1, or over n_items inputs."""
  top = np.array([[1, 2, 3], [4, 5, 6]])
  bottom = np.array([[10, 20, 30], [40, 50, 60]])

  layer0 = cb.Concatenate(axis=0)
  self.assertEqual(
      layer0([top, bottom]).tolist(),
      [[1, 2, 3], [4, 5, 6], [10, 20, 30], [40, 50, 60]])

  layer1 = cb.Concatenate(axis=1)
  self.assertEqual(
      layer1([top, bottom]).tolist(),
      [[1, 2, 3, 10, 20, 30], [4, 5, 6, 40, 50, 60]])

  layer2 = cb.Concatenate(n_items=3)
  self.assertEqual(
      layer2([top, bottom, top]).tolist(),
      [[1, 2, 3, 10, 20, 30, 1, 2, 3], [4, 5, 6, 40, 50, 60, 4, 5, 6]])

  expected_reprs = [
      'Concatenate_axis0_in2', 'Concatenate_axis1_in2', 'Concatenate_in3'
  ]
  for layer, expected in zip((layer0, layer1, layer2), expected_reprs):
    self.assertEqual(repr(layer), expected)
def test_fn_layer_example(self):
  """Fn wraps a two-output lambda; checks output shapes and values."""
  layer = base.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))

  signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
  self.assertEqual(
      base.check_shape_agreement(layer, signature), ((2, 7), (4, 7)))

  total, stacked = layer((np.array([2]), np.array([3])))
  self.assertEqual(int(total), 5)
  self.assertEqual([int(v) for v in stacked], [2, 3])
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
  """Query a table with a position vector.

  Args:
    x: query position vector(s).
    keys: table keys. When None, this is the identity and `x` is returned.
    values: table values matching `keys`.
    binary: if True, the query is `x` concatenated with itself along the last
      axis (for tables keyed by pairs of positions).
    **unused_kwargs: ignored.

  Returns:
    `x` itself if `keys` is None, otherwise the output of
    `tl.DotProductAttention` on (query, keys, values). The trailing
    None/0.0/None/None arguments are passed through unchanged; presumably they
    are mask/dropout/mode settings (confirm against tl.DotProductAttention).
  """
  if keys is None:
    return x
  key_array, value_array = np.array(keys), np.array(values)
  query = np.concatenate([x, x], axis=-1) if binary else x
  return tl.DotProductAttention(
      query, key_array, value_array, None, 0.0, None, None)
def ProbsRatioMean(x, **unused_kwargs):
  """Probability Ratio Mean from the PPO algorithm.

  Closure over `self`: uses `self._policy_dist.log_prob`.

  Args:
    x: tuple (dist_inputs, _, _, actions, old_log_probs).
    **unused_kwargs: ignored.

  Returns:
    Mean over all entries of exp(new_log_probs - old_log_probs).
  """
  dist_inputs, _, _, actions, old_log_probs = x
  fresh_log_probs = self._policy_dist.log_prob(dist_inputs, actions)

  # Both log-prob arrays carry an unwanted trailing dimension; drop it.
  stale = jnp.array(old_log_probs.squeeze(axis=-1), dtype=jnp.float32)
  fresh = jnp.array(fresh_log_probs.squeeze(axis=-1))

  # new_probs / old_probs, computed as the exponential of the log difference.
  return jnp.mean(jnp.exp(fresh - stale))
def test_scan_basic(self):
  """Scan over a two-output adder layer yields running sums and the total."""

  @base.layer(n_in=2, n_out=2)
  def add(x, **unused_kwargs):
    total = x[0] + x[1]
    return total, total  # The same value is both the output and new carry.

  scan_layer = cb.Scan(add())  # pylint: disable=no-value-for-parameter
  signature = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
  self.assertEqual(
      base.check_shape_agreement(scan_layer, signature),
      ((3, 2, 7), (2, 7)))

  outputs, carry = scan_layer((np.array([1, 2, 3]), np.array(0)))
  self.assertEqual(int(carry), 6)
  self.assertEqual([int(o) for o in outputs], [1, 3, 6])
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
  """Make a n+1 dim one-hot array from n dim int-categorical array.

  Args:
    x: integer array of category ids.
    size: number of categories, i.e. the length of the new trailing axis.
    dtype: dtype of the result.

  Returns:
    Array broadcast to x.shape + (size,), with 1 where the trailing index
    equals the corresponding entry of `x` and 0 elsewhere.
  """
  categories = np.arange(size)
  if math.backend_name() == 'jax':
    # Work around a jax broadcasting issue.
    categories = jax.lax.tie_in(x, categories)
  expanded = x[..., np.newaxis]
  return np.array(expanded == categories, dtype)
def _multi_device_put(x, devices=None):
  """Replicates `x` onto several devices as a single ShardedDeviceArray.

  JAX's ShardedDeviceArray holds one device buffer per device. It is consumed
  by pmap'd computations without extra inter-device transfers. JAX has no
  memory-efficient multi-device 'put', so this one is assembled by hand: one
  `jax.device_put` per device, then the buffers are wrapped together.

  Args:
    x: jax DeviceArray or numpy ndarray to be replicated.
    devices: a jax.devices() list or subset thereof to replicate onto. It
      should match the devices of any pmap that ingests the result. Falls
      back to jax.local_devices() when None or empty.

  Returns:
    A ShardedDeviceArray with dtype x.dtype and shape (n_devices,) + x.shape,
    backed by one replicated device buffer per device.
  """
  # Values such as _FilledConstants lack `device_buffer`, etc., so anything
  # that is not exactly a DeviceArray is materialized first.
  if type(x) != jax.xla.DeviceArray:  # pylint: disable=unidiomatic-typecheck
    x = np.array(x)

  devices = devices or jax.local_devices()

  # Abstract value of the result: a leading device axis on top of x's shape.
  element_aval = jax.xla.abstractify(x)
  replicated_aval = jax.abstract_arrays.ShapedArray(
      (len(devices),) + element_aval.shape, element_aval.dtype)

  # One independent copy of x's buffer on every target device.
  per_device_buffers = [
      jax.device_put(x, device).device_buffer for device in devices
  ]
  return jax.pxla.ShardedDeviceArray(replicated_aval, per_device_buffers)
def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding.

  Returns the first `x.shape[1]` rows of `positions`, repeated across the
  leading (batch) axis of `x`. Only the shape of `x` is used, not its values.

  Args:
    x: input whose first two axes are (batch, length).
    positions: 2-D table of position vectors, one per row.
    **kwargs: ignored.

  Returns:
    Array of shape (batch, length, positions.shape[-1]).

  Raises:
    ValueError: if `positions` is None.
  """
  del kwargs
  if positions is None:
    raise ValueError('NewPositionalEncoding requires `positions`.')
  batch_size = np.shape(x)[0]
  x_length = np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  # Broadcast on batch. This must be out-of-place: an in-place `+=` would try
  # to grow the (1, length, d) operand to batch_size rows, which NumPy
  # rejects for batch_size > 1.
  return pos + np.zeros((batch_size, 1, 1))
def update_model_state(self, key, value):
  """Updates model state based on nontrainable_params.

  Args:
    key: model state key, looked up directly in `self.nontrainable_params`.
    value: current state value, returned unchanged when `key` is absent.

  Returns:
    `self._for_n_devices(np.array(self.nontrainable_params[key]))` when `key`
    is a nontrainable param, otherwise `value`.
  """
  if key not in self.nontrainable_params:
    return value
  return self._for_n_devices(np.array(self.nontrainable_params[key]))
def test_scan_basic(self):
  """Scan over an Fn-built adder yields running sums and the final carry."""

  def add():  # pylint: disable=invalid-name
    def f(x, carry):
      total = x + carry
      return total, total  # output and carry are the same
    return base.Fn('add', f, n_out=2)

  scan_layer = cb.Scan(add())
  signature = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
  self.assertEqual(
      base.check_shape_agreement(scan_layer, signature),
      ((3, 2, 7), (2, 7)))

  outputs, carry = scan_layer((np.array([1, 2, 3]), np.array(0)))
  self.assertEqual(int(carry), 6)
  self.assertEqual([int(o) for o in outputs], [1, 3, 6])
def __init__(self, learning_rate, clip_grad_norm=None, **init_opt_params):
  """Sets initial hyperparameter values for this optimizer.

  Initial optimizer parameters are given as keyword arguments. Their values
  can change between training steps, e.g. for learning-rate schedules.

  To expose hyperparameters of a subclass to gin configuration, override this
  constructor with explicitly named keyword arguments (see
  `momentum.Momentum.__init__` for an example).

  Args:
    learning_rate: The initial learning rate.
    clip_grad_norm: float; the value to which gradients will be clipped.
    **init_opt_params: Initial values of any additional optimizer parameters.
  """
  all_params = dict(init_opt_params, learning_rate=learning_rate)
  self._init_opt_params = {k: np.array(v) for k, v in all_params.items()}
  self._slots = None
  # Clipping uses the norm of the entire gradient tree, so this class applies
  # it to the whole tree instead of passing it to single-slot updates.
  self._clip_grad_norm = clip_grad_norm
def one_hot(x, n_categories, dtype=np.float32):  # pylint: disable=invalid-name
  """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims).

  Args:
    x: integer array of category ids.
    n_categories: length of the new trailing one-hot axis.
    dtype: dtype of the result.

  Returns:
    Array broadcast to x.shape + (n_categories,), with 1 where the trailing
    index equals the corresponding entry of `x` and 0 elsewhere.
  """
  category_ids = np.arange(n_categories)
  if math.backend_name() == 'jax':
    # Work around a jax broadcasting issue.
    category_ids = jax.lax.tie_in(x, category_ids)
  hits = x[..., np.newaxis] == category_ids
  return np.array(hits, dtype)
def EntropyLoss(x, **unused_kwargs):
  """Definition of the Entropy Layer.

  Closure over `self`: uses `self._policy_dist` and `self._entropy_coeff`.

  Args:
    x: tuple (dist_inputs, _, _, actions).
    **unused_kwargs: ignored.

  Returns:
    self._policy_dist.entropy(log_probs) * self._entropy_coeff.
  """
  dist_inputs, _, _, actions = x
  # NOTE(review): entropy is evaluated on the log-probs of the taken actions
  # (trailing dim squeezed); confirm that is what `entropy` expects.
  log_probs = jnp.array(
      self._policy_dist.log_prob(dist_inputs, actions).squeeze(axis=-1))
  return self._policy_dist.entropy(log_probs) * self._entropy_coeff
def test_pure_layer_value_forward(self):
  """PureLayer doubles its input via all three forward entry points."""
  layer = base.PureLayer(lambda x: 2 * x)

  # Layer.__call__.
  self.assertEqual(layer(np.array([1, 2])).tolist(), [2, 4])

  # PureLayer.forward with explicit empty weights.
  via_forward = layer.forward(np.array([3, 4]), base.EMPTY_WEIGHTS)
  self.assertEqual(via_forward.tolist(), [6, 8])

  # Layer.forward_with_state returns (output, state); the state is unused.
  via_state, _ = layer.forward_with_state(np.array([5, 6]))
  self.assertEqual(via_state.tolist(), [10, 12])
def test_from_file(self):
  """InitializerFromFile returns the array stored in the given .npy file."""
  params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
  path = self.create_tempfile('params.npy').full_path
  with open(path, 'wb') as out_file:
    np.save(out_file, params)

  initializer = initializers.InitializerFromFile(path)
  loaded = initializer((3, 2), random.get_prng(0))
  self.assertEqual('%s' % loaded, '%s' % params)
def Fn(f, n_in=None, n_out=None):  # pylint: disable=invalid-name
  """Returns a layer with no weights that applies the function f.

  The function f can take and return any number of arguments, but it cannot
  have default arguments or keyword arguments. It can use numpy though, e.g.
  a layer that takes 2 arguments and returns sum and concatenation on stack:

    Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))

  Sometimes determining the number of outputs automatically fails; in such
  cases specify n_in and n_out.

  Args:
    f: the function to execute
    n_in: optional, number of inputs
    n_out: optional, number of outputs

  Returns:
    A layer executing the function f.

  Raises:
    ValueError: if f has default or keyword arguments, if n_in is unset and f
      takes *args, or if n_out is unset and calling f on dummy inputs fails.
  """
  # Inspect f to restrict to functions without defaults or **kwargs.
  # Feature-detect rather than check the Python version: Python 3 has
  # getfullargspec (field `varkw`); Python 2 only has getargspec (`keywords`).
  getfullargspec = getattr(inspect, 'getfullargspec', None)
  if getfullargspec is not None:
    argspec = getfullargspec(f)
    varkwargs = argspec.varkw
  else:
    argspec = inspect.getargspec(f)  # pylint: disable=deprecated-method
    varkwargs = argspec.keywords

  # This layer cannot handle functions with kwargs or defaults.
  if argspec.defaults is not None:
    raise ValueError('function cannot have default arguments')
  if varkwargs:
    raise ValueError('function cannot have keyword arguments')

  # Determine n_in from function signature if not set.
  if n_in is None:
    if argspec.varargs is not None:
      raise ValueError('n_in is not set and f has variable args')
    n_in = len(argspec.args)

  # Try to determine n_out by calling f on dummy inputs.
  if n_out is None:
    try:
      dummy_args = [np.array([[0.0]]) for _ in range(n_in)]
      res = f(*dummy_args)
      n_out = len(res) if isinstance(res, (list, tuple)) else 1
    except Exception:  # pylint: disable=broad-except
      # f may fail on dummy inputs for many reasons, so report a uniform
      # ValueError. Catch Exception, not everything, so KeyboardInterrupt and
      # SystemExit still propagate.
      raise ValueError('n_out is not set and could not be determined')

  # Create the layer.
  @layer(n_in=n_in, n_out=n_out)
  def F(xs, **unused_kwargs):  # pylint: disable=invalid-name
    if not isinstance(xs, (tuple, list)):
      xs = (xs,)
    return f(*xs)

  return F()  # pylint: disable=no-value-for-parameter
def ProbsRatio(dist_inputs, actions, old_log_probs, log_prob_fun):
  """Probability Ratio from the PPO algorithm.

  Args:
    dist_inputs: inputs to the policy distribution, forwarded to NewLogProbs.
    actions: taken actions, forwarded to NewLogProbs.
    old_log_probs: log-probs under the old policy, with a trailing singleton
      dimension that is squeezed away.
    log_prob_fun: log-prob function, forwarded to NewLogProbs.

  Returns:
    exp(new_log_probs - old_log_probs), i.e. new_probs / old_probs.
  """
  # The old log-probs carry an unwanted trailing dimension; drop it.
  stale = jnp.array(old_log_probs.squeeze(axis=-1), dtype=jnp.float32)
  fresh = NewLogProbs(dist_inputs, actions, log_prob_fun)
  # Ratio of probabilities, computed via exponentiation in log space.
  return jnp.exp(fresh - stale)
def ApproximateKLDivergence(dist_inputs, actions, old_log_probs, log_prob_fun):
  """Approximate KL divergence between the old and new policy (PPO).

  Uses the estimator 0.5 * mean((log p_new - log p_old) ** 2), i.e. half the
  mean of the squared log-ratios.

  Args:
    dist_inputs: inputs to the policy distribution, forwarded to NewLogProbs.
    actions: taken actions, forwarded to NewLogProbs.
    old_log_probs: log-probs under the old policy, with a trailing singleton
      dimension that is squeezed away.
    log_prob_fun: log-prob function, forwarded to NewLogProbs.

  Returns:
    Scalar approximate KL divergence.
  """
  # TODO(henrykm): Clarify the old_log_probs and squeezing
  # Old log probs have an undesirable extra dimension which we remove here
  old_log_probs = jnp.array(old_log_probs.squeeze(axis=-1), dtype=jnp.float32)
  new_log_probs = NewLogProbs(dist_inputs, actions, log_prob_fun)
  log_ratio = new_log_probs - old_log_probs
  # Square before averaging. Squaring the mean would let positive and
  # negative log-ratios cancel and report ~0 for clearly different policies.
  approximate_kl_divergence = 0.5 * jnp.mean(log_ratio**2)
  return approximate_kl_divergence
def test_from_file(self):
  """InitializerFromFile returns the array stored in the given .npy file."""
  params = np.array([[0.0, 0.1], [0.2, 0.3], [0.4, 0.5]])
  # `create_tempfile` needs access to --test_tmpdir. In the OSS world pytest
  # does not run `absltest.main`, so the flags must be parsed manually.
  test_utils.ensure_flag('test_tmpdir')
  path = self.create_tempfile('params.npy').full_path
  with open(path, 'wb') as out_file:
    np.save(out_file, params)

  initializer = initializers.InitializerFromFile(path)
  loaded = initializer((3, 2), random.get_prng(0))
  self.assertEqual('%s' % loaded, '%s' % params)
def test_scan_nocarry(self):
  """Scan with n_carry=0 maps a stateless Fn over the leading axis."""

  def addone():  # pylint: disable=invalid-name
    return base.Fn('addone', lambda x: x + 1)

  scan_layer = cb.Scan(addone(), n_carry=0)
  self.assertEqual(
      base.check_shape_agreement(scan_layer, ShapeDtype((3, 2, 7))),
      (3, 2, 7))

  outputs = scan_layer(np.array([1, 2, 3]))
  self.assertEqual([int(o) for o in outputs], [2, 3, 4])
def __init__(self,
             mode=None,
             learn_epsilon=False,
             init_epsilon=1e-6,
             init_learnt_epsilon=1e-4):
  """Configures how FilterResponseNorm chooses its epsilon.

  If `learn_epsilon` is True, epsilon = init_epsilon + |learnt_value|, where
  learnt_value starts at init_learnt_epsilon. Otherwise epsilon is just
  init_epsilon.

  NOTE: I (afrozm) haven't been able to train with `learn_epsilon = True`.

  Args:
    mode: ignored.
    learn_epsilon: whether to add a learnt component to epsilon.
    init_epsilon: fixed epsilon component; must be > 0.
    init_learnt_epsilon: initial learnt component; must be > 0.
  """
  super(FilterResponseNorm, self).__init__()
  del mode
  self._learn_epsilon = learn_epsilon
  assert init_epsilon > 0
  assert init_learnt_epsilon > 0
  # Store both epsilons as float32 arrays.
  self._init_epsilon, self._init_learnt_epsilon = (
      np.array(eps, dtype=np.float32)
      for eps in (init_epsilon, init_learnt_epsilon))
def new_weights_and_state(self, input_signature):
  """Creates sinusoidal positional-encoding weights and the initial state.

  The weights have shape (1, self._max_len, d_feature). Even feature columns
  hold sin(position * rate_k) and odd columns hold cos(position * rate_k),
  where rate_k = exp(-k * log(10000) / d_feature) for even k. Both even and
  odd d_feature are supported.

  Args:
    input_signature: signature whose last shape entry is d_feature.

  Returns:
    (weights, state), where state is 0 in 'predict' mode and
    base.EMPTY_STATE otherwise.
  """
  d_feature = input_signature.shape[-1]
  pe = onp.zeros((self._max_len, d_feature), dtype=onp.float32)
  position = onp.arange(0, self._max_len)[:, onp.newaxis]
  div_term = onp.exp(
      onp.arange(0, d_feature, 2) * -(onp.log(10000.0) / d_feature))
  pe[:, 0::2] = onp.sin(position * div_term)
  # There are ceil(d/2) even (sin) columns but only floor(d/2) odd (cos)
  # columns. For odd d_feature, use only the first d_feature // 2 rates so
  # the shapes match.
  pe[:, 1::2] = onp.cos(position * div_term[:d_feature // 2])
  pe = pe[onp.newaxis, :, :]  # [1, self._max_len, d_feature]
  weights = np.array(pe)  # These are trainable parameters, initialized above.
  state = 0 if self._mode == 'predict' else base.EMPTY_STATE
  return weights, state
def test_scan_nocarry(self):
  """Scan with n_carry=0 maps a stateless layer over the leading axis."""

  @base.layer(n_in=1, n_out=1)
  def addone(x, **unused_kwargs):
    return x + 1

  scan_layer = cb.Scan(addone(), n_carry=0)  # pylint: disable=no-value-for-parameter
  self.assertEqual(
      base.check_shape_agreement(scan_layer, ShapeDtype((3, 2, 7))),
      (3, 2, 7))

  outputs = scan_layer(np.array([1, 2, 3]))
  self.assertEqual([int(o) for o in outputs], [2, 3, 4])
def update_model_state(self, key, value):
  """Updates model state based on nontrainable_params.

  The model state `key` is first translated through
  `self._nontrainable_param_map`. Keys missing from the map are used as param
  names unchanged.

  Args:
    key: model state key.
    value: current state value, returned unchanged if no matching
      nontrainable param exists.

  Returns:
    The matching nontrainable param as an array, passed through
    `self._for_n_devices`, or `value` if there is no match.
  """
  mapping = self._nontrainable_param_map
  p_name = mapping[key] if key in mapping else key
  if p_name not in self.nontrainable_params:
    return value
  if self._step == 0:
    # Only logged on the first step to avoid repeating the message.
    log('Mapping model state key {} to nontrainable param {}.'.format(
        key, p_name))
  return self._for_n_devices(np.array(self.nontrainable_params[p_name]))
def new_weights(self, input_signature):
  """Creates sinusoidal positional-encoding weights.

  The weights have shape (1, self._max_len, d_feature). Even feature columns
  hold sin(position * rate_k) and odd columns hold cos(position * rate_k),
  where rate_k = exp(-k * log(10000) / d_feature) for even k. Both even and
  odd d_feature are supported. In 'predict' mode this also sets `self.state`
  to int32 zeros of shape (batch_size,).

  Args:
    input_signature: signature whose shape is (batch_size, ..., d_feature).

  Returns:
    The weights as a jnp array.
  """
  d_feature = input_signature.shape[-1]
  pe = np.zeros((self._max_len, d_feature), dtype=np.float32)
  position = np.arange(0, self._max_len)[:, np.newaxis]
  div_term = np.exp(
      np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
  pe[:, 0::2] = np.sin(position * div_term)
  # There are ceil(d/2) even (sin) columns but only floor(d/2) odd (cos)
  # columns. For odd d_feature, use only the first d_feature // 2 rates so
  # the shapes match.
  pe[:, 1::2] = np.cos(position * div_term[:d_feature // 2])
  pe = pe[np.newaxis, :, :]  # [1, self._max_len, d_feature]
  weights = jnp.array(pe)  # Trainable parameters, initialized above.
  if self._mode == 'predict':
    batch_size = input_signature.shape[0]
    self.state = jnp.zeros((batch_size,), dtype=jnp.int32)
  return weights
def update(self, step, grads, weights, slots, opt_params):
  """Applies one Adam step with multiplicative weight decay.

  The weights are first scaled by (1 - weight_decay_rate), independently of
  the learning rate. Then the bias-corrected Adam step is subtracted, cast to
  the weights' dtype.

  Args:
    step: current step number (0-based).
    grads: gradients for `weights`.
    weights: current weights.
    slots: (m, v) first and second moment estimates.
    opt_params: dict with 'learning_rate', 'weight_decay_rate', 'b1', 'b2'
      and 'eps'.

  Returns:
    (new_weights, (m, v)) with the updated moment estimates.
  """
  first_moment, second_moment = slots
  lr = opt_params['learning_rate']
  decay = opt_params['weight_decay_rate']
  b1, b2, eps = opt_params['b1'], opt_params['b2'], opt_params['eps']
  # 1-based step count used for bias correction, made a proper int32 first.
  t = np.array(step).astype(np.int32) + 1

  first_moment = (1 - b1) * grads + b1 * first_moment
  second_moment = (1 - b2) * (grads**2) + b2 * second_moment
  m_hat = first_moment / (1 - b1**t)
  v_hat = second_moment / (1 - b2**t)

  adam_step = (lr * m_hat / (np.sqrt(v_hat) + eps)).astype(weights.dtype)
  new_weights = (1 - decay) * weights - adam_step
  return new_weights, (first_moment, second_moment)
def mapped_update(i, opt_state, batch, state, rng):
  """Multi-device version of the single-device update step.

  All tensors are assumed to have a leading n_devices dimension. Gradients are
  averaged over the 'batch' mapped axis before the optimizer update.

  Args:
    i: step index passed to optimizer.tree_update.
    opt_state: (weights, slots, opt_params).
    batch: input batch for model_and_loss_call.
    state: model state, threaded through the loss's auxiliary output.
    rng: random key; split into one key used now and one returned.

  Returns:
    (optimizer.tree_update result, new model state, next rng).
  """
  weights, slots, opt_params = opt_state
  step_rng, next_rng = jax_random.split(rng)
  grad_fn = math.grad(model_and_loss_call, has_aux=True)
  grads, state = grad_fn(weights, batch, state, step_rng)

  def _mean_over_devices(grad):
    # psum(1.0) counts every device on every host (e.g. a TPU pod) that takes
    # part in the 'batch' axis. `n_devices` would only count this host's.
    return math.psum(grad, 'batch') / math.psum(np.array(1.0), 'batch')

  grads = jax.tree_util.tree_map(_mean_over_devices, grads)
  new_opt_state = optimizer.tree_update(i, grads, weights, slots, opt_params)
  return new_opt_state, state, next_rng
def __init__(self, learning_rate, **init_opt_params):
  """Initialize the optimizer.

  Initial optimizer parameters are given as keyword arguments and stored as
  arrays in `self._init_opt_params`, together with the learning rate. They can
  be changed between updates, e.g. for learning-rate schedules. Derived
  classes should override this constructor to name their parameters so gin
  configuration can set them.

  Args:
    learning_rate: The initial learning rate.
    **init_opt_params: Initial values of any additional optimizer parameters.
  """
  opt_params = dict(init_opt_params, learning_rate=learning_rate)
  self._init_opt_params = {k: np.array(v) for k, v in opt_params.items()}