def BatchNorm(x, params, axis=(0, 1, 2), epsilon=1e-5,
              center=True, scale=True, **unused_kwargs):
  """Layer construction function for a batch normalization layer."""
  mean = np.mean(x, axis, keepdims=True)
  # Fast but less numerically-stable variance calculation than np.var.
  m1 = np.mean(x**2, axis, keepdims=True)
  var = m1 - mean**2
  # x mustn't be onp.ndarray here; otherwise `x-mean` will call mean.__rsub__
  # with each element of x, resulting in an onp.ndarray with dtype `object`.
  z = (x - mean) / np.sqrt(var + epsilon).astype(x.dtype)

  # Expand the parameters to have the right axes.
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]

  # Return the z rescaled by the parameters if requested.
  if center and scale:
    ret = gamma * z + beta
  elif center:
    ret = z + beta
  elif scale:
    ret = gamma * z
  else:
    ret = z
  assert ret.dtype == x.dtype, ('The dtype of the output (%s) of batch norm is '
                                'not the same as the input (%s). Batch norm '
                                'should not change the dtype'
                                % (ret.dtype, x.dtype))
  return ret
def BatchNorm(x, params, axis=(0, 1, 2), epsilon=1e-5,
              center=True, scale=True, **unused_kwargs):
  """Layer construction function for a batch normalization layer."""
  mean = np.mean(x, axis, keepdims=True)
  # Fast but less numerically-stable variance calculation than np.var.
  m1 = np.mean(x**2, axis, keepdims=True)
  var = m1 - mean**2
  z = (x - mean) / np.sqrt(var + epsilon)

  # Expand the parameters to have the right axes.
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]

  # Return the z rescaled by the parameters if requested.
  if center and scale:
    return gamma * z + beta
  if center:
    return z + beta
  if scale:
    return gamma * z
  return z
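# A minimal usage sketch for the stateless BatchNorm variants above (not part
# of the original code). It assumes `np` in these snippets is jax.numpy and
# that params is a (beta, gamma) pair with one entry per axis left out of
# `axis` (the channel axis here); shapes are illustrative only.
import jax.numpy as jnp

def _batch_norm_example():
  x = jnp.ones((8, 32, 32, 3))   # NHWC batch of images
  beta = jnp.zeros((3,))         # per-channel offset
  gamma = jnp.ones((3,))         # per-channel scale
  return BatchNorm(x, (beta, gamma), axis=(0, 1, 2))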
def apply_fun(params, inputs, **kwargs):
  del kwargs
  (scale, bias) = params
  mean = np.mean(inputs, axis=-1, keepdims=True)
  variance = np.mean((inputs - mean)**2, axis=-1, keepdims=True)
  norm_inputs = (inputs - mean) / np.sqrt(variance + epsilon)
  return norm_inputs * scale + bias
def call(self, x, params, state, **unused_kwargs):
  """Layer construction function for a batch normalization layer."""
  running_mean, running_var, num_batches = state

  if self._mode == 'train':
    mean = np.mean(x, self._axis, keepdims=True)
    # Fast but less numerically-stable variance calculation than np.var.
    m1 = np.mean(x**2, self._axis, keepdims=True)
    var = m1 - mean**2
    num_batches = num_batches + 1
    if self._momentum is None:
      # A simple average over all batches seen so far.
      exponential_average_factor = 1.0 / num_batches
    else:
      exponential_average_factor = self._momentum

    def average(factor, new, old):
      return (factor * new + (1 - factor) * old).astype(old.dtype)

    running_mean = average(exponential_average_factor, mean, running_mean)
    running_var = average(exponential_average_factor, var, running_var)
    state = (running_mean, running_var, num_batches)
  else:
    mean = running_mean
    var = running_var

  z = (x - mean.astype(x.dtype)) / np.sqrt(var + self._epsilon).astype(
      x.dtype)

  # Expand the parameters to have the right axes.
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in self._axis else slice(None)
             for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]

  # Return the z rescaled by the parameters if requested.
  if self._center and self._scale:
    output = gamma * z + beta
  elif self._center:
    output = z + beta
  elif self._scale:
    output = gamma * z
  else:
    output = z
  assert output.dtype == x.dtype, ('The dtype of the output (%s) of batch '
                                   'norm is not the same as the input (%s). '
                                   'Batch norm should not change the dtype'
                                   % (output.dtype, x.dtype))
  return output, state
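# A hedged sketch (not from the source) of how the stateful `call` above might
# be driven. It assumes `layer` is a hypothetical instance configured with
# axis=(0, 1, 2), and that state is the (running_mean, running_var,
# num_batches) triple unpacked above, initialized here to zero mean and unit
# variance; shapes are illustrative only.
import jax.numpy as jnp

def _stateful_batch_norm_example(layer, x):
  channels = x.shape[-1]
  params = (jnp.zeros((channels,)), jnp.ones((channels,)))   # (beta, gamma)
  state = (jnp.zeros((1, 1, 1, channels)),                   # running_mean
           jnp.ones((1, 1, 1, channels)),                    # running_var
           0)                                                # num_batches
  y, new_state = layer.call(x, params, state)  # updates state in 'train' mode
  return y, new_state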
def update(self, step, grads, params, slots, opt_params):
  updates = []
  learning_rate = opt_params["learning_rate"]
  beta1 = opt_params["beta1"]
  decay_rate = opt_params["decay_rate"]
  clipping_threshold = opt_params["clipping_threshold"]
  weight_decay_rate = opt_params["weight_decay_rate"]
  epsilon1 = opt_params["epsilon1"]
  epsilon2 = opt_params["epsilon2"]
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    update_scale *= np.maximum(np.sqrt(np.mean(params * params)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads + epsilon1
  if self._factored and len(params.shape) >= 2:
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (grads * np.expand_dims(row_factor, axis=-1) *
         np.expand_dims(col_factor, axis=-2))
  else:
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v)**-0.5

  if self._do_clipping:
    clipping_denom = (
        np.maximum(1.0, np.sqrt(np.mean(y * y)) / clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_params = (1 - weight_decay_rate) * params - subtrahend
  # TODO(lukaszkaiser): why is the astype needed here? Check and correct.
  return new_params.astype(params.dtype), updates
def masked_mean(inputs, targets, mask_id=None):
  """Mean of the inputs but counting only those where targets != mask_id."""
  x = inputs.astype(np.float32)
  if mask_id is None:
    return np.mean(x)
  unmask = 1.0 - np.equal(targets, mask_id).astype(np.float32)
  return np.sum(x * unmask) / np.sum(unmask)
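# A tiny worked example for masked_mean above (not from the source, assumes
# jax.numpy): with mask_id=0, positions where targets == 0 are treated as
# padding and dropped from the average.
import jax.numpy as jnp

def _masked_mean_example():
  losses = jnp.array([1.0, 2.0, 3.0, 4.0])
  targets = jnp.array([5, 7, 0, 0])   # last two positions are padding
  return masked_mean(losses, targets, mask_id=0)   # (1.0 + 2.0) / 2 == 1.5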
def combine(x):
  if len(x.shape) > 1:
    batch_size = x.shape[0] * x.shape[1]
    return np.reshape(x, [batch_size] + list(x.shape[2:]))
  # TODO(lukaszkaiser): is returning averages for scalars the right choice?
  # If it is only scalar, return the average.
  return np.mean(x, axis=0)
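# Illustrative sketch for combine above (not from the source, assumes
# jax.numpy): arrays with a leading (group, batch) pair of axes are flattened
# into a single batch axis, while 1-D per-group values are averaged.
import jax.numpy as jnp

def _combine_example():
  x = jnp.zeros((4, 8, 16))             # 4 groups of 8 examples each
  flat = combine(x)                     # shape (32, 16)
  avg = combine(jnp.array([1.0, 3.0]))  # averaged to 2.0
  return flat.shape, avg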
def update(self, step, grads, params, slots, opt_params):
  updates = []
  (learning_rate, beta1, decay_rate, clipping_threshold,
   weight_decay_rate, epsilon1, epsilon2) = opt_params
  decay_rate = self._decay_rate_pow(step, exponent=decay_rate)
  update_scale = learning_rate
  if self._multiply_by_parameter_scale:
    update_scale *= np.maximum(np.sqrt(np.mean(params * params)), epsilon2)
  mixing_rate = 1.0 - decay_rate

  grads_sqr = grads * grads + epsilon1
  if self._factored and len(params.shape) >= 2:
    v_row = slots.pop(0)
    v_col = slots.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(grads_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(grads_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (grads * np.expand_dims(row_factor, axis=-1) *
         np.expand_dims(col_factor, axis=-2))
  else:
    v = slots.pop(0)
    new_v = decay_rate * v + mixing_rate * grads_sqr
    updates.append(new_v)
    y = grads * (new_v)**-0.5

  if self._do_clipping:
    clipping_denom = (
        np.maximum(1.0, np.sqrt(np.mean(y * y)) / clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._do_momentum:
    m = slots.pop(0)
    new_m = beta1 * m + (1.0 - beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_params = (1 - weight_decay_rate) * params - subtrahend
  return new_params, updates
def update(self, i, g, x, state):
  updates = []
  decay_rate = self._decay_rate(i)
  update_scale = self._step_size(i)
  if self._multiply_by_parameter_scale:
    update_scale *= np.maximum(np.sqrt(np.mean(x * x)), self._epsilon2)
  mixing_rate = 1.0 - decay_rate

  g_sqr = g * g + self._epsilon1
  if self._factored and len(x.shape) >= 2:
    v_row = state.pop(0)
    v_col = state.pop(0)
    new_v_row = decay_rate * v_row + mixing_rate * np.mean(g_sqr, axis=-1)
    new_v_col = decay_rate * v_col + mixing_rate * np.mean(g_sqr, axis=-2)
    updates.extend([new_v_row, new_v_col])
    row_col_mean = np.mean(new_v_row, axis=-1, keepdims=True)
    row_factor = (new_v_row / row_col_mean)**-0.5
    col_factor = (new_v_col)**-0.5
    y = (g * np.expand_dims(row_factor, axis=-1) *
         np.expand_dims(col_factor, axis=-2))
  else:
    v = state.pop(0)
    new_v = decay_rate * v + mixing_rate * g_sqr
    updates.append(new_v)
    y = g * (new_v)**-0.5

  if self._clipping_threshold is not None:
    clipping_denom = (
        np.maximum(1.0, np.sqrt(np.mean(y * y)) / self._clipping_threshold))
    y /= clipping_denom

  subtrahend = update_scale * y
  if self._beta1:
    m = state.pop(0)
    new_m = self._beta1 * m + (1.0 - self._beta1) * subtrahend
    subtrahend = new_m
    updates.append(new_m)

  new_x = x - subtrahend
  return new_x, updates
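# A short sketch (not part of the original code) of the factored second-moment
# trick shared by the Adafactor-style `update` functions above: for a 2-D
# parameter, the per-element accumulator is summarized by row and column means
# whose rescaled outer product is what row_factor * col_factor implicitly
# reconstructs. Assumes jax.numpy; the helper name is illustrative only.
import jax.numpy as jnp

def _factored_second_moment(grads):
  grads_sqr = grads * grads
  v_row = jnp.mean(grads_sqr, axis=-1)   # one value per row
  v_col = jnp.mean(grads_sqr, axis=-2)   # one value per column
  approx_v = jnp.outer(v_row, v_col) / jnp.mean(v_row)
  # grads * approx_v**-0.5 matches grads * row_factor * col_factor above
  # (with decay/mixing and the epsilons omitted for brevity).
  return grads * approx_v**-0.5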
def masked_mean(inputs, targets, mask_id=None):
  """Mean of the inputs but counting only those where targets != mask_id."""
  inputs = [x.astype(np.float32) for x in inputs]
  # We assume all elements in the list contribute equally.
  # TODO(lukaszkaiser): remove this assumption (e.g., when masks differ).
  length = len(inputs)
  if mask_id is None:
    # TODO(lukaszkaiser): can we just divide the sum by length? XLA optimizes?
    return sum([np.mean(x) / length for x in inputs])
  unmask = [1.0 - np.equal(t, mask_id).astype(np.float32) for t in targets]
  return sum([np.sum(x * m) / (length * np.sum(m))
              for x, m in zip(inputs, unmask)])
def apply_fun(params, x, **kwargs):
  beta, gamma = params
  # TODO(phawkins): np.expand_dims should accept an axis tuple.
  # (https://github.com/numpy/numpy/issues/12290)
  ed = tuple(None if i in axis else slice(None) for i in range(np.ndim(x)))
  beta = beta[ed]
  gamma = gamma[ed]
  mean, var = np.mean(x, axis, keepdims=True), fastvar(x, axis, keepdims=True)
  z = (x - mean) / np.sqrt(var + epsilon)
  if center and scale:
    return gamma * z + beta
  if center:
    return z + beta
  if scale:
    return gamma * z
  return z
def Mean(x, params, axis=-1, keepdims=False, **kwargs):
  del params, kwargs
  return np.mean(x, axis=axis, keepdims=keepdims)
def fastvar(x, axis, keepdims):
  """A fast but less numerically-stable variance calculation than np.var."""
  m1 = np.mean(x**2, axis, keepdims=keepdims)
  m2 = np.mean(x, axis, keepdims=keepdims)**2
  return m1 - m2
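# Quick illustrative check (not from the source, assumes jax.numpy): fastvar
# agrees with the library variance up to floating-point error, since
# Var[x] = E[x**2] - E[x]**2; the one-pass form trades some numerical
# stability for speed, as the docstring notes.
import jax.numpy as jnp

def _fastvar_example():
  x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
  return fastvar(x, axis=-1, keepdims=True), jnp.var(x, axis=-1, keepdims=True)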
def crossentropy_loss(logpred, target):
  """Calculate crossentropy loss."""
  return -np.mean(
      np.sum(logpred * slax.one_hot(target, logpred.shape[-1]), axis=-1))
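# Illustrative only (not from the source): the same negative log-likelihood
# spelled out with jax.nn.one_hot standing in for slax.one_hot, which is an
# assumption about that helper. `logpred` is assumed to hold log-probabilities.
import jax
import jax.numpy as jnp

def _crossentropy_example(logits, target):
  logpred = jax.nn.log_softmax(logits)
  onehot = jax.nn.one_hot(target, logpred.shape[-1])
  return -jnp.mean(jnp.sum(logpred * onehot, axis=-1))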
def LayerNorm(x, params, epsilon=1e-6, **unused_kwargs):
  (scale, bias) = params
  mean = np.mean(x, axis=-1, keepdims=True)
  variance = np.mean((x - mean)**2, axis=-1, keepdims=True)
  norm_inputs = (x - mean) / np.sqrt(variance + epsilon)
  return norm_inputs * scale + bias
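# Minimal usage sketch for LayerNorm above (not from the source, assumes
# jax.numpy): scale and bias carry one entry per feature, and normalization
# runs over the last axis; shapes below are illustrative.
import jax.numpy as jnp

def _layer_norm_example():
  x = jnp.ones((2, 5, 16))                      # (batch, length, features)
  params = (jnp.ones((16,)), jnp.zeros((16,)))  # (scale, bias)
  return LayerNorm(x, params)                   # same shape as x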
def make_unit_length(self, x, epsilon=1e-6):
  variance = np.mean(x**2, axis=-1, keepdims=True)
  norm_inputs = x / np.sqrt(variance + epsilon)
  return norm_inputs
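# Illustrative check (not from the source, assumes jax.numpy, with `layer` a
# hypothetical instance providing the method): make_unit_length rescales each
# vector along the last axis to roughly unit root-mean-square, i.e. an L2 norm
# of about sqrt(d) for d features rather than exactly 1.
import jax.numpy as jnp

def _make_unit_length_example(layer):
  x = jnp.array([[3.0, 4.0]])
  y = layer.make_unit_length(x)
  return jnp.sqrt(jnp.mean(y**2, axis=-1))   # ~1.0 per row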
def Mean(x, axis=-1, keepdims=False, **unused_kwargs):
  return np.mean(x, axis=axis, keepdims=keepdims)