def GeneralGRUCell(candidate_transform,
                   memory_transform_fn=None,
                   gate_nonlinearity=activation_fns.Sigmoid,
                   candidate_nonlinearity=activation_fns.Tanh,
                   dropout_rate_c=0.1,
                   sigmoid_bias=0.5):
  r"""Builds a configurable Gated Recurrent Unit (GRU) cell.

  GRU update equations:
  $$ Update gate: u_t = \sigmoid(U' * s_{t-1} + B') $$
  $$ Reset gate: r_t = \sigmoid(U'' * s_{t-1} + B'') $$
  $$ Candidate memory: c_t = \tanh(U * (r_t \odot s_{t-1}) + B) $$
  $$ New State: s_t = u_t \odot s_{t-1} + (1 - u_t) \odot c_t $$

  See combinators.Gate for details on the gating function.

  Args:
    candidate_transform: Factory for the transform used in the update gate,
      the reset gate and the candidate branch; it is called once per use and
      applied before the nonlinearities.
    memory_transform_fn: Optional factory for a transformation applied to the
      memory before gating. When falsy, the memory is passed through as-is.
    gate_nonlinearity: Factory for the gate activation. Allows trying
      alternatives to Sigmoid, such as HardSigmoid.
    candidate_nonlinearity: Factory for the nonlinearity applied after the
      candidate branch. Allows trying alternatives to Tanh, such as HardTanh.
    dropout_rate_c: Dropout rate for the candidate (c) branch, the only branch
      that receives dropout.
    sigmoid_bias: Constant added before each gate activation. Generally want
      to start off with a positive bias.

  Returns:
    A model representing a GRU cell with the specified transforms.
  """
  def _biased_gate():
    # Transform, constant shift (so the gate starts positive), activation.
    return [
        candidate_transform(),
        base.Fn(lambda x: x + sigmoid_bias),
        gate_nonlinearity(),
    ]

  update_gate = _biased_gate()  # u_t
  reset_gate = _biased_gate()  # r_t

  candidate_branch = [
      cb.Dup(),
      reset_gate,
      # Gate S{t-1} with sigmoid(candidate_transform(S{t-1})).
      cb.Multiply(),
      # Final projection + nonlinearity to get Ct.
      candidate_transform(),
      candidate_nonlinearity(),
      # Only apply dropout on the C gate. Paper reports 0.1 as a good default.
      core.Dropout(rate=dropout_rate_c),
  ]

  if memory_transform_fn:
    memory_branch = memory_transform_fn()
  else:
    memory_branch = []

  gru_branches = cb.Branch(memory_branch, update_gate, candidate_branch)
  return cb.Serial(gru_branches, cb.Gate())
def test_fn_layer_varargs_n_in(self):
  """Checks Fn with a varargs function, without and with an explicit n_in."""
  # Without n_in, building the layer is expected to raise.
  with self.assertRaisesRegex(ValueError, 'variable arg'):
    base.Fn(lambda *args: args[0])

  # Check that varargs work when n_in is set.
  first_arg_layer = base.Fn(lambda *args: args[0], n_in=1)
  signature = ShapeDtype((2, 7))
  actual_shape = base.check_shape_agreement(first_arg_layer, signature)
  self.assertEqual(actual_shape, (2, 7))
def test_fn_layer_difficult_n_out(self):
  """Checks Fn with a function whose output count needs an explicit n_out.

  The wrapped function concatenates along axis 4, so it only works on inputs
  with at least five axes.
  """
  with self.assertRaisesRegex(ValueError, 'n_out'):
    # Determining the output of this layer is hard with dummies.
    # NOTE(review): presumably Fn probes the function with low-rank dummy
    # inputs, on which axis=4 is out of range -- confirm against base.Fn.
    base.Fn(lambda x: np.concatenate([x, x], axis=4))

  # Check that this layer works when n_out is set.
  layer = base.Fn(lambda x: np.concatenate([x, x], axis=4), n_out=1)
  input_signature = ShapeDtype((2, 1, 2, 2, 3))
  expected_shape = (2, 1, 2, 2, 6)
  output_shape = base.check_shape_agreement(layer, input_signature)
  self.assertEqual(output_shape, expected_shape)
def test_output_signature(self):
  """Checks Fn.output_signature for a 2-in/1-out and a 1-in/3-out layer."""
  # Two inputs of equal shape, summed into one output.
  pair_signature = (ShapeDtype((2, 3, 5)), ShapeDtype((2, 3, 5)))
  add_layer = base.Fn(lambda x, y: x + y)  # n_in = 2, n_out = 1
  self.assertEqual(add_layer.output_signature(pair_signature),
                   ShapeDtype((2, 3, 5)))

  # One input fanned out into three outputs of the same shape.
  single_signature = ShapeDtype((5, 7))
  fan_out_layer = base.Fn(lambda x: (x, 2 * x, 3 * x))  # n_in = 1, n_out = 3
  result = fan_out_layer.output_signature(single_signature)
  self.assertEqual(result, (ShapeDtype((5, 7)), ) * 3)
  self.assertNotEqual(result, (ShapeDtype((4, 7)), ) * 3)
  self.assertNotEqual(result, (ShapeDtype((5, 7)), ) * 2)
def test_output_signature(self):
  """Checks Fn.output_signature for a 2-in/1-out and a 1-in/3-out layer."""
  # Two inputs of equal shape, summed into one output.
  pair_signature = (shapes.ShapeDtype((2, 3, 5)),
                    shapes.ShapeDtype((2, 3, 5)))
  add_layer = base.Fn('2in1out', lambda x, y: x + y)
  self.assertEqual(add_layer.output_signature(pair_signature),
                   shapes.ShapeDtype((2, 3, 5)))

  # One input fanned out into three outputs of the same shape.
  single_signature = shapes.ShapeDtype((5, 7))
  fan_out_layer = base.Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3)
  result = fan_out_layer.output_signature(single_signature)
  self.assertEqual(result, (shapes.ShapeDtype((5, 7)),) * 3)
  self.assertNotEqual(result, (shapes.ShapeDtype((4, 7)),) * 3)
  self.assertNotEqual(result, (shapes.ShapeDtype((5, 7)),) * 2)
def test_weights_state(self):
  """Checks that an Fn layer reports empty weights and empty state."""
  sum_and_concat = base.Fn(
      '2in2out',
      lambda x, y: (x + y, jnp.concatenate([x, y], axis=0)),
      n_out=2)
  new_weights, new_state = sum_and_concat.new_weights_and_state(None)
  self.assertEmpty(new_weights)
  self.assertEmpty(new_state)
def MakeZeroState(depth_multiplier=1):
  """Makes zeros of shape like x but removing the length (axis 1).

  Args:
    depth_multiplier: Factor applied to the last dimension of the input to
      obtain the last dimension of the zero state.

  Returns:
    A layer mapping a rank 3 input to a float32 zero array of shape
    `(x.shape[0], depth_multiplier * x.shape[-1])`.
  """
  def f(x):  # pylint: disable=invalid-name
    # Raise instead of `assert`, so the check survives `python -O`.
    if len(x.shape) != 3:
      raise ValueError(
          'Layer input should be a rank 3 tensor representing'
          ' (batch_size, sequence_length, feature_depth); '
          f'instead got shape {x.shape}.')
    return jnp.zeros((x.shape[0], depth_multiplier * x.shape[-1]),
                     dtype=jnp.float32)
  return base.Fn('MakeZeroState', f)
def Bidirectional(forward_layer, axis=1, merge_layer=None):
  """Bidirectional combinator for RNNs.

  The input is processed by `forward_layer` as given, and by a deep copy of
  `forward_layer` on the input flipped along `axis`; the copy's output is
  flipped back before both outputs are handed to `merge_layer`.

  Args:
    forward_layer: A layer, such as `trax.layers.LSTM` or `trax.layers.GRU`.
    axis: a time axis of the inputs. Default value is `1`.
    merge_layer: A combinator used to combine outputs of the forward
      and backward RNNs. If `None` (the default), a new
      `trax.layers.Concatenate` layer is created for this call.

  Example:
      Bidirectional(RNN(n_units=8))

  Returns:
    The Bidirectional combinator for RNNs.
  """
  if merge_layer is None:
    # A default of `Concatenate()` in the signature would be evaluated once,
    # at definition time, and that one layer object would then be shared by
    # every combinator built with the default.
    merge_layer = Concatenate()

  backward_layer = copy.deepcopy(forward_layer)
  flip = base.Fn('_FlipAlongTimeAxis', lambda x: jnp.flip(x, axis=axis))
  backward = Serial(
      flip,
      backward_layer,
      flip,
  )

  return Serial(
      Branch(forward_layer, backward),
      merge_layer,
  )
def WeightedFScore(beta=1., initial_category_index=0):
  """Returns a layer that computes a weighted F-score.

  The weighted F-score summarizes how well the classifier's `k` predictions
  align with the observed/gold instances of `k`. It additionally weights the
  summary by the number of observed/gold and predicted examples in each class.

  Args:
    beta: a parameter that determines the weight of recall in the F-score.
    initial_category_index: an index of the initial category.

  The layer takes two inputs:

    - Model output from one batch, an ndarray of float-valued elements.

    - A batch of element-wise target values, which matches the shape of the
      model output.

  The layer returns a weighted F-score across all the classes.
  """
  def weighted_f_score(model_output, targets):
    beta_squared = beta**2
    predicted = jnp.argmax(model_output, axis=-1)
    n_classes = model_output.shape[-1]

    # Per-class F-scores and their weights, grown one class at a time.
    class_scores = jnp.empty(0)
    class_weights = jnp.empty(0)
    for category in range(initial_category_index, n_classes):
      _, _, n_category_targets, precision, recall = _precision_recall(
          predicted, targets, category)
      category_score = _f_score(precision, recall, beta_squared)
      class_scores = jnp.append(class_scores, category_score)
      class_weights = jnp.append(class_weights, n_category_targets)

    return jnp.average(class_scores, weights=class_weights)

  return base.Fn('WeightedFScore', weighted_f_score)
def CategoryAccuracy():
  r"""Returns a layer that computes category prediction accuracy.

  The layer takes two inputs:

    - A batch of activation vectors. Within a vector, a higher component value
      should correspond to a higher probability, so that argmax over the last
      axis (``axis=-1``) picks the most probable index (category).

    - A batch of target categories; each target is an integer in
      :math:`\{0, ..., N-1\}`.

  The predicted category for each vector is the index of its highest-valued
  component. The layer returns the fraction of predictions equal to their
  targets, over all elements of the batch.
  """
  def category_accuracy(model_output, targets):
    predicted = jnp.argmax(model_output, axis=-1)
    shapes.assert_same_shape(predicted, targets)
    n_hits = jnp.sum(jnp.equal(predicted, targets))
    return n_hits / predicted.size

  return base.Fn('CategoryAccuracy', category_accuracy)
def SmoothL1Loss():
  r"""Returns a layer that computes a weighted, smoothed L1 loss for one batch.

  The layer takes three inputs:

    - Model output from one batch, an ndarray of float-valued elements.

    - A batch of element-wise target values, which matches the shape of the
      model output.

    - A batch of weights, which matches the shape of the model output.

  For model output float :math:`y_i` and target float :math:`t_i`, the
  "smooth" L1 loss (a.k.a. Huber loss) is:

  .. math::
      \text{output} = \left\{ \begin{array}{cl}
          \frac 1 2 (y_i - t_i)^2, & \text{if}\ |y_i - t_i| < 1, \\
          |y_i - t_i| - \frac 1 2, & \text{otherwise}.
      \end{array} \right.

  The layer returns the weighted average of these element-wise values.
  """
  def smooth_l1(model_output, targets, weights):
    shapes.assert_same_shape(model_output, targets)
    shapes.assert_same_shape(model_output, weights)
    abs_error = jnp.abs(model_output - targets)
    # Quadratic close to zero, linear elsewhere.
    per_element = jnp.where(abs_error < 1,
                            0.5 * abs_error**2,
                            abs_error - 0.5)
    weighted_total = jnp.sum(weights * per_element)
    return weighted_total / jnp.sum(weights)

  return base.Fn('SmoothL1Loss', smooth_l1)
def _Accuracy():
  """Returns a layer marking where predicted and target categories agree.

  The layer takes predicted categories and target categories and returns a
  float32 array holding 1.0 where they are equal and 0.0 elsewhere.
  """
  def accuracy(predicted_category, target_category):
    # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment.
    # shapes.assert_same_shape(predicted_category, target_category)
    is_match = jnp.equal(predicted_category, target_category)
    return is_match.astype(jnp.float32)

  return base.Fn('_Accuracy', accuracy)
def CategoryCrossEntropy():
  """Returns a layer that computes cross entropy from activations and integers.

  The layer takes two inputs:

    - A batch of activation vectors. The components of a vector should be
      pre-softmax activations (mappable to a probability distribution via
      softmax). For performance reasons, the softmax and cross entropy
      computations are combined inside the layer.

    - A batch of target categories; each target is an integer in
      `{0, ..., N-1}`, where `N` is the activation vector depth/dimensionality.

  Notionally, both inputs are turned into probability distributions:

    - activation vectors: vector --> SoftMax(vector)

    - target categories: integer --> OneHot(integer)

  (One-hot conversion assigns all the probability mass to the target
  category.) Cross entropy per batch item is then, notionally:

      cross_entropy(one_hot(targets), softmax(model_output))

  The layer returns the average of these cross-entropy values over all items
  in the batch.
  """
  def category_cross_entropy(model_output, targets):
    per_item = _category_cross_entropy(model_output, targets)
    return jnp.average(per_item)

  return base.Fn('CategoryCrossEntropy', category_cross_entropy)
def WeightedCategoryCrossEntropy():
  r"""Returns a layer like ``CategoryCrossEntropy``, with weights as third input.

  The layer takes three inputs:

    - A batch of activation vectors. The components of a vector should be
      pre-softmax activations (mappable to a probability distribution via
      softmax). For performance reasons, the softmax and cross-entropy
      computations are combined inside the layer.

    - A batch of target categories; each target is an integer in
      :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector
      depth/dimensionality.

    - A batch of weights, which matches or can be broadcast to match the shape
      of the target ndarray. This arg can give uneven weighting to different
      items in the batch (depending, for instance, on the item's target
      category).

  The layer returns the weighted average of the per-item cross-entropy values:
  the sum of weighted values divided by the sum of the weights.
  """
  def weighted_cross_entropy(model_output, targets, weights):
    per_item = _category_cross_entropy(model_output, targets)
    weighted_total = jnp.sum(per_item * weights)
    return weighted_total / jnp.sum(weights)

  return base.Fn('WeightedCategoryCrossEntropy', weighted_cross_entropy)
def SmoothL1Loss():
  """Returns a layer that computes total smooth L1 loss for one batch."""
  def huber(model_output, targets, weights):
    r"""Returns weighted smooth L1 norm of `model_output - targets`.

    The smooth L1 loss, also known as the Huber loss, is defined as:

    .. math::
        z_i =
        \begin{cases}
        0.5 (x_i - y_i)^2, & \text{if } |x_i - y_i| < 1 \\
        |x_i - y_i| - 0.5, & \text{otherwise }
        \end{cases}

    Args:
      model_output: Output from one batch, treated as an unanalyzed tensor.
      targets: Tensor of same shape as `model_output` containing element-wise
        target values.
      weights: Tensor of same shape as `model_output` and `targets`,
        containing element-wise weight values.
    """
    shapes.assert_same_shape(model_output, targets)
    shapes.assert_same_shape(targets, weights)
    abs_error = jnp.abs(model_output - targets)
    # Quadratic close to zero, linear elsewhere.
    per_element = jnp.where(abs_error < 1,
                            0.5 * abs_error**2,
                            abs_error - 0.5)
    shapes.assert_same_shape(per_element, weights)
    weighted_total = jnp.sum(weights * per_element)
    return weighted_total / jnp.sum(weights)

  return base.Fn('SmoothL1Loss', huber)
def WeightedCategoryAccuracy():
  r"""Returns a layer that computes a weighted category prediction accuracy.

  The layer takes three inputs:

    - A batch of activation vectors. Within a vector, a higher component value
      should correspond to a higher probability, so that argmax over the last
      axis (``axis=-1``) picks the most probable index (category).

    - A batch of target categories; each target is an integer in
      :math:`\{0, ..., N-1\}`, where :math:`N` is the activation vector
      depth/dimensionality.

    - A batch of weights, which matches or can be broadcast to match the shape
      of the target ndarray. This arg can give uneven weighting to different
      items in the batch (depending, for instance, on the item's target
      category).

  The predicted category for each vector is the index of its highest-valued
  component. The layer returns the weighted average accuracy of these
  predictions: the weighted count of matches divided by the sum of weights.
  """
  def weighted_accuracy(model_output, targets, weights):
    predicted = jnp.argmax(model_output, axis=-1)
    shapes.assert_same_shape(predicted, targets)
    is_match = jnp.equal(predicted, targets)
    weighted_hits = jnp.sum(is_match * weights)
    return weighted_hits / jnp.sum(weights)

  return base.Fn('WeightedCategoryAccuracy', weighted_accuracy)
def MacroAveragedFScore(beta=1., initial_category_index=0):
  r"""Returns a layer that computes a macro-averaged F-score.

  Args:
    beta: a parameter that determines the weight of recall in the F-score.
    initial_category_index: an index of the initial category.

  The layer takes two inputs:

    - Model output from one batch, an ndarray of float-valued elements.

    - A batch of element-wise target values, which matches the shape of the
      model output.

  The layer returns a macro-averaged F-score across all the classes: the
  unweighted mean of the per-class F-scores, where any undefined (NaN)
  precision, recall or F-score counts as 0.
  """
  def f(model_output, targets):  # pylint: disable=invalid-name
    def non_nan(x):  # pylint: disable=invalid-name
      return jnp.where(jnp.isnan(x), 0., x)

    beta2 = beta**2
    predictions = jnp.argmax(model_output, axis=-1)
    n_categories = model_output.shape[-1]
    f_scores = jnp.empty(0)
    for k in range(initial_category_index, n_categories):
      is_predicted = predictions == k
      is_target = targets == k
      # Use jnp.sum so counts cover every element: the builtin sum() only
      # reduces the leading axis, and on an empty input it yields a Python 0
      # whose 0/0 raises ZeroDivisionError instead of producing NaN.
      n_correct = jnp.sum(is_predicted & is_target)
      precision = non_nan(n_correct / jnp.sum(is_predicted))
      recall = non_nan(n_correct / jnp.sum(is_target))
      f_score = non_nan((beta2 + 1) * (precision * recall) /
                        ((beta2 * precision) + recall))
      f_scores = jnp.append(f_scores, f_score)
    return jnp.mean(f_scores)

  return base.Fn('MacroAveragedFScore', f)
def InnerSRUCell():
  """The inner (non-parallel) computation of an SRU.

  The returned layer takes `(cur_x_times_one_minus_f, cur_f, cur_state)` and
  computes `cur_f * cur_state + cur_x_times_one_minus_f`, emitting that value
  twice (two outputs).
  """
  def sru_step(cur_x_times_one_minus_f, cur_f, cur_state):
    new_state = cur_f * cur_state + cur_x_times_one_minus_f
    return new_state, new_state

  return base.Fn('InnerSRUCell', sru_step, n_out=2)
def _CrossEntropy():
  """Returns a layer that computes prediction-target cross entropies.

  The layer one-hot encodes the target categories to the depth of the model
  output's last axis, multiplies element-wise with the model output, sums
  over the last axis and negates the result.
  """
  def cross_entropy(model_output, target_category):
    # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment.
    # shapes.assert_shape_equals(target_category, model_output.shape[:-1])
    depth = model_output.shape[-1]
    one_hot_targets = core.one_hot(target_category, depth)
    return -1.0 * jnp.sum(model_output * one_hot_targets, axis=-1)

  return base.Fn('_CrossEntropy', cross_entropy)
def TestModelSavingInputs():
  """Returns a layer that records its inputs and broadcasts them to logits.

  Each input is appended to the enclosing `test_model_inputs` list, then cast
  to float32 and broadcast along a new trailing axis of size `vocab_size`.
  """
  def save_and_broadcast(inputs):
    # Save the inputs for a later check.
    test_model_inputs.append(inputs)
    # Change type to np.float32 and add the logit dimension.
    with_logit_axis = inputs.astype(np.float32)[:, :, None]
    logits_shape = inputs.shape + (vocab_size,)
    return jnp.broadcast_to(with_logit_axis, logits_shape)

  return layers_base.Fn('TestModelSavingInputs', save_and_broadcast)
def MakeZeroState(depth_multiplier=1):
  """Makes zeros of shape like x but removing the length (axis 1).

  Args:
    depth_multiplier: Factor applied to the last dimension of the input to
      obtain the last dimension of the zero state.

  Returns:
    A layer mapping a rank 3 input to a float32 zero array of shape
    `(x.shape[0], depth_multiplier * x.shape[-1])`; the layer raises
    ValueError for inputs of any other rank.
  """
  def zero_state(x):
    rank = len(x.shape)
    if rank != 3:
      message = ('Layer input should be a rank 3 tensor representing'
                 ' (batch_size, sequence_length, feature_depth); '
                 f'instead got shape {x.shape}.')
      raise ValueError(message)
    state_shape = (x.shape[0], depth_multiplier * x.shape[-1])
    return jnp.zeros(state_shape, dtype=jnp.float32)

  return base.Fn('MakeZeroState', zero_state)
def test_fn_layer_example(self):
  """Checks shapes and values of a 2-in/2-out Fn layer (sum and concat)."""
  sum_and_concat = base.Fn(
      lambda x, y: (x + y, np.concatenate([x, y], axis=0)))

  # Shape check: the sum keeps the shape, the concat doubles axis 0.
  signature = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
  actual_shape = base.check_shape_agreement(sum_and_concat, signature)
  self.assertEqual(actual_shape, ((2, 7), (4, 7)))

  # Value check on one-element arrays.
  total, joined = sum_and_concat((np.array([2]), np.array([3])))
  self.assertEqual(int(total), 5)
  self.assertEqual(list(map(int, joined)), [2, 3])
def _BinaryCrossEntropy():
  """Returns a layer that computes prediction-target cross entropies.

  The layer requires model output and targets of the same shape. It takes the
  dot product of the transposed targets with `log(model_output)`, adds the dot
  product of the transposed `1 - targets` with `log(1 - model_output)`, and
  scales the squeezed sum by `-1 / batch_size`, where `batch_size` is the
  size of the model output's first axis.
  """
  def binary_cross_entropy(model_output, target_category):
    shapes.assert_same_shape(model_output, target_category)
    n_examples = model_output.shape[0]
    positive_term = jnp.dot(jnp.transpose(target_category),
                            jnp.log(model_output))
    negative_term = jnp.dot(jnp.transpose(1 - target_category),
                            jnp.log(1 - model_output))
    log_likelihood = positive_term + negative_term
    return -1.0/n_examples * jnp.squeeze(log_likelihood)

  return base.Fn('_BinaryCrossEntropy', binary_cross_entropy)
def SRU(n_units, activation=None, mode='train'):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  In this implementation the highway term :math:`(1 - r_t) \times x_t` is
  additionally multiplied by :math:`\sqrt{3}`.

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    mode: passed on to `ScanSRUCell`; if 'predict' then we save the previous
      state for one-by-one inference.

  Returns:
    The SRU layer.
  """
  sigmoid = activation_fns.Sigmoid()

  # Stages are listed in data-flow order; trailing comments show the stack.
  project = cb.Branch(core.Dense(3 * n_units), [])  # r_f_y, x
  split = cb.Split(n_items=3)  # r, f, y, x
  squash_gates = cb.Parallel(sigmoid, sigmoid)  # r, f, y, x
  prepare_scan = base.Fn(
      '',
      lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
      n_out=3)
  add_zero_state = cb.Parallel([], [], cb.Branch(MakeZeroState(), []))
  recurrence = ScanSRUCell(mode=mode)
  keep_output = cb.Select([0], n_in=2)  # act(c), r, x
  if activation is not None:
    post_activation = activation
  else:
    post_activation = []
  highway = base.Fn('FinalSRUGate',
                    lambda c, r, x: c * r + x * (1 - r) * (3**0.5))

  return cb.Serial(
      project,
      split,
      squash_gates,
      prepare_scan,
      add_zero_state,
      recurrence,
      keep_output,
      post_activation,
      highway,
      # Set the name to SRU and don't print sublayers.
      name=f'SRU_{n_units}',
      sublayers_to_print=[])
def _merge_heads():
  """Returns a layer that undoes splitting, after multi-head computation.

  Uses `n_heads` and `d_head` from the enclosing scope.
  """
  def merge(x):
    # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
    n_steps = x.shape[1]
    per_head = x.reshape((-1, n_heads, n_steps, d_head))
    heads_beside_depth = per_head.transpose((0, 2, 1, 3))
    return heads_beside_depth.reshape((-1, n_steps, n_heads * d_head))

  return base.Fn('MergeHeads', merge)
def TestModel(extra_dim):
  """Dummy sequence model for testing.

  Args:
    extra_dim: Size of a trailing axis to broadcast the float32 inputs along,
      or None to return the float32 inputs unchanged.

  Returns:
    A layer casting its input to float32 and optionally adding the extra axis.
  """
  def cast_and_expand(inputs):
    # Cast the input to float32 - this is for simulating discrete-input models.
    as_float = inputs.astype(np.float32)
    if extra_dim is None:
      return as_float
    # Add an extra dimension if requested, e.g. the logit dimension for output
    # symbols.
    expanded_shape = as_float.shape + (extra_dim,)
    return jnp.broadcast_to(as_float[:, :, None], expanded_shape)

  return layers_base.Fn('TestModel', cast_and_expand)
def _split_into_heads():
  """Returns a layer that reshapes tensors for multi-headed computation.

  Uses `n_heads` and `d_head` from the enclosing scope.
  """
  def split(x):
    # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
    n_batch, n_steps = x.shape[0], x.shape[1]
    per_head = x.reshape((n_batch, n_steps, n_heads, d_head))
    heads_beside_batch = per_head.transpose((0, 2, 1, 3))
    return heads_beside_batch.reshape((-1, n_steps, d_head))

  return base.Fn('SplitIntoHeads', split)
def SRU(n_units, activation=None):
  r"""SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:

  .. math::
    y_t &= W x_t + B \quad \hbox{(include $B$ optionally)} \\
    f_t &= \sigma(Wf x_t + bf) \\
    r_t &= \sigma(Wr x_t + br) \\
    c_t &= f_t \times c_{t-1} + (1 - f_t) \times y_t \\
    h_t &= r_t \times \hbox{activation}(c_t) + (1 - r_t) \times x_t

  In this implementation the highway term :math:`(1 - r_t) \times x_t` is
  additionally multiplied by :math:`\sqrt{3}`.

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.

  Returns:
    The SRU layer.
  """
  sigmoid = activation_fns.Sigmoid()

  # Stages are listed in data-flow order; trailing comments show the stack.
  stages = [
      cb.Branch(core.Dense(3 * n_units), []),  # r_f_y, x
      cb.Split(n_items=3),  # r, f, y, x
      cb.Parallel(sigmoid, sigmoid),  # r, f, y, x
      base.Fn(
          '',
          lambda r, f, y: (y * (1.0 - f), f, r),  # y * (1 - f), f, r, x
          n_out=3),
      cb.Parallel([], [], cb.Branch(MakeZeroState(), [])),
      cb.Scan(InnerSRUCell(), axis=1),
      cb.Select([0], n_in=2),  # act(c), r, x
      activation if activation else [],
      base.Fn('FinalSRUGate',
              lambda c, r, x: c * r + x * (1 - r) * (3**0.5)),
  ]
  return cb.Serial(*stages)
def _WeightedSequenceMean():
  """Returns a layer that computes a weighted sequence accuracy mean.

  The layer takes per-element values and weights, and returns the fraction of
  batch items with no weighted incorrect element.
  """
  def sequence_mean(values, weights):
    # This function assumes weights are 0 or 1.
    # Per element: 1 where not correct, 0 where correct or masked.
    errors = (1.0 - values) * weights
    # Sum the errors over every axis except batch. These are 0s and 1s, so a
    # sequence's sum is 0 only if it has no error, and >= 1 otherwise.
    non_batch_axes = list(range(1, len(errors.shape)))
    errors_per_sequence = jnp.sum(errors, axis=non_batch_axes)
    # Clip to 1 and invert: 1 for a fully correct sequence, else 0.
    sequence_ok = 1.0 - jnp.minimum(1.0, errors_per_sequence)
    return jnp.mean(sequence_ok)  # Mean over batch.

  return base.Fn('_WeightedSequenceMean', sequence_mean)
def SRU(n_units, activation=None, rescale=False, highway_bias=0):
  """SRU (Simple Recurrent Unit) layer as in https://arxiv.org/abs/1709.02755.

  As defined in the paper:
  (1) y_t = W x_t (+ B optionally, which we do)
  (2) f_t = sigmoid(Wf x_t + bf)
  (3) r_t = sigmoid(Wr x_t + br)
  (4) c_t = f_t * c_{t-1} + (1 - f_t) * y_t
  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t * alpha

  We assume the input is of shape [batch, length, depth] and recurrence
  happens on the length dimension. This returns a single layer. It's best
  to use at least 2, they say in the paper, except inside a Transformer.

  Args:
    n_units: output depth of the SRU layer.
    activation: Optional activation function.
    rescale: Whether to apply the scaling correction alpha, meant to offset
      the gradient vanishing in h_t caused by light recurrence and highway
      computation in deeper layers. When True,
      alpha = (1 + exp(highway_bias) * 2)**0.5; when False, alpha = 1.
      ref: https://arxiv.org/abs/1709.02755, page 4, section 3.2
      Initialization.
    highway_bias: initial bias of highway gates; in this function it is only
      used to compute alpha.

  Returns:
    The SRU layer.
  """
  # pylint: disable=no-value-for-parameter
  # Stages are named in data-flow order; trailing comments show the stack.
  project = cb.Branch(core.Dense(3 * n_units), [])  # r_f_y, x
  split = cb.Split(n_items=3)  # r, f, y, x
  squash_gates = cb.Parallel(core.Sigmoid(), core.Sigmoid())  # r, f, y, x
  prepare_scan = base.Fn(
      lambda r, f, y: (y * (1.0 - f), f, r))  # y * (1 - f), f, r, x
  add_zero_state = cb.Parallel([], [], cb.Branch(MakeZeroState(), []))
  recurrence = cb.Scan(InnerSRUCell(), axis=1)
  keep_output = cb.Select([0], n_in=2)  # act(c), r, x
  post_activation = activation if activation else []
  highway = base.Fn(
      lambda c, r, x: c * r + x * (1 - r) * (
          (1 + np.exp(highway_bias) * 2)**0.5 if rescale else 1))

  return cb.Serial(
      project,
      split,
      squash_gates,
      prepare_scan,
      add_zero_state,
      recurrence,
      keep_output,
      post_activation,
      highway,
  )