def one_hot_multiply(inputs, scale):
  """Performs (inputs * scale) % vocab_size in the one-hot space.

  Args:
    inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
      Tensor.
    scale: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
      Tensor specifying how much to scale the corresponding one-hot vector in
      inputs. Soft values perform a "weighted scale": for example,
      scale=[0.2, 0.3, 0.5] performs a linear combination of
      0.2 * scaling by zero; 0.3 * scaling by one; and 0.5 * scaling by two.

  Returns:
    Tensor of same shape and dtype as inputs.
  """
  # TODO(trandustin): Implement with circular conv1d.
  inputs = tf.convert_to_tensor(inputs)
  scale = tf.cast(scale, inputs.dtype)
  batch_shape = inputs.shape[:-1].as_list()
  vocab_size = inputs.shape[-1]
  # Form a [..., vocab_size, vocab_size] tensor. The ith row of the
  # batched vocab_size x vocab_size matrix represents scaling inputs by i.
  permutation_matrix = tf.math.floormod(
      tf.tile(tf.range(vocab_size)[:, tf.newaxis], [1, vocab_size]) *
      tf.range(vocab_size)[tf.newaxis], vocab_size)
  permutation_matrix = tf.one_hot(permutation_matrix, depth=vocab_size, axis=-1)
  # Scale the inputs according to the permutation matrix of all possible scales.
  scaled_inputs = tf.einsum('...v,avu->...au', inputs, permutation_matrix)
  scaled_inputs = tf.concat([tf.zeros(batch_shape + [1, vocab_size]),
                             scaled_inputs[..., 1:, :]], axis=-2)
  # Reduce rows of the scaled inputs by the scale values. This forms a
  # weighted linear combination of scaling by zero, scaling by one, and so on.
  outputs = tf.einsum('...v,...vu->...u', scale, scaled_inputs)
  return outputs
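# Hypothetical usage sketch (not from the source): with hard one-hot inputs,
# one_hot_multiply should reproduce integer multiplication mod vocab_size,
# e.g. (2 * 3) % 4 == 2. Assumes `tf` is TensorFlow 2.x in eager mode.
vocab_size = 4
inputs = tf.one_hot([2], depth=vocab_size)  # represents the integer 2
scale = tf.one_hot([3], depth=vocab_size)   # scale by 3
outputs = one_hot_multiply(inputs, scale)
print(tf.argmax(outputs, axis=-1).numpy())  # [2], since (2 * 3) % 4 == 2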
def train_step(self, regularizer: float = 1e-6):
  # Solve primal form min (1-g) * E[nu0] + E[(B nu - nu)^2].
  with tf.GradientTape() as tape:
    nu_sigma = tf.sqrt(tf.exp(self._nu_log_sigma))
    eps = tf.random.normal(tf.shape(nu_sigma), 0, self._eps_std)
    nu = self._nu_mu + nu_sigma * eps
    init_nu_loss = tf.einsum('m,m', (1 - self._gamma) * self._initial_weights,
                             nu)
    residuals = tf.einsum('n,nm->m', nu, self._td_residuals)
    bellman_loss = 0.5 * tf.einsum('m,m', residuals, residuals)
    prior_sigma = tf.sqrt(tf.exp(self._prior_log_sigma))
    prior_var = tf.square(prior_sigma)
    prior_var = 1.
    neg_kl = (0.5 * (1. - 2. * tf.math.log(prior_sigma / nu_sigma + 1e-8) -
                     (self._nu_mu - self._prior_mu)**2 / prior_var -
                     nu_sigma**2 / prior_var))
    loss = init_nu_loss + bellman_loss - self._kl_regularizer * neg_kl

  grads = tape.gradient(loss, [
      self._nu_mu, self._nu_log_sigma, self._prior_mu, self._prior_log_sigma
  ])
  self._nu_optimizer.apply_gradients(
      zip(grads, [
          self._nu_mu, self._nu_log_sigma, self._prior_mu,
          self._prior_log_sigma
      ]))

  return loss
def call(self, input_tensor):
  """Forward pass for dense layer with 3D kernel.

  Args:
    input_tensor: float Tensor of shape [batch, seq_length, hidden_size].

  Returns:
    float logits Tensor.
  """
  hidden_size = self.num_attention_heads * self.size_per_head
  reshape_w = tf.reshape(
      self.w, [hidden_size, self.num_attention_heads, self.size_per_head])
  if self.head_first:
    ret = tf.einsum("abc,cde->adbe", input_tensor, reshape_w)
  else:
    ret = tf.einsum("abc,cde->abde", input_tensor, reshape_w)

  if self.use_bias:
    if self.head_first:
      reshape_b = tf.reshape(
          self.b, [1, self.num_attention_heads, 1, self.size_per_head])
    else:
      reshape_b = tf.reshape(
          self.b, [self.num_attention_heads, self.size_per_head])
    ret += reshape_b

  if self.activation is not None:
    return self.activation(ret)
  else:
    return ret
def cost(self, x):
  self.beta = tf.reshape(x, (x.shape[0], 1))
  beta_b_beta = (self.beta_b - self.beta)
  beta_b_beta_transpose = tf.transpose(beta_b_beta)
  beta_b_beta_Q_0 = tf.matmul(beta_b_beta_transpose, self.Q_inv)
  J_b = tf.matmul(beta_b_beta_Q_0, beta_b_beta)
  J_0 = 0
  if self.window == 1:
    Y_new = tf.reshape(self.Y[self.t], [self.n_dim, 1])
    residual = Y_new - tf.matmul(self.X[0], self.beta)
    residual_t = tf.transpose(residual)
    residual_sigma = tf.matmul(residual_t, self.sigma_inv)
    J_0 = tf.matmul(residual_sigma, residual)
  else:
    effective_window = self.window
    if self.t + effective_window > self.time:
      effective_window = self.time - self.t
    t = self.t
    tau = self.t + effective_window
    x_beta = tf.einsum("abc,cd->abd", self.X, self.beta)
    Y_new = tf.reshape(self.Y[t:tau], [effective_window, self.n_dim, 1])
    residual = Y_new - x_beta
    residual_t = tf.transpose(residual, perm=[0, 2, 1])
    residual_sigma = tf.einsum("abc,cd->abd", residual_t, self.sigma_inv)
    total_residual_cost = tf.einsum("abc,acd->abd", residual_sigma, residual)
    J_0 = tf.reduce_sum(total_residual_cost, 0)
  return tf.reshape(J_b + J_0, ())
def _variance(self):
  with tf.control_dependencies(self._runtime_assertions):
    probs = self._marginal_hidden_probs()
    # probs :: num_steps batch_shape num_states
    means = self._observation_distribution.mean()
    # means :: observation_batch_shape[:-1] num_states
    #          observation_event_shape
    means_shape = tf.concat([
        self.batch_shape_tensor(), [self._num_states],
        self._observation_distribution.event_shape_tensor()
    ], axis=0)
    means = tf.broadcast_to(means, means_shape)
    # means :: batch_shape num_states observation_event_shape

    observation_event_shape = (
        self._observation_distribution.event_shape_tensor())
    batch_size = tf.reduce_prod(self.batch_shape_tensor())
    flat_probs_shape = [self._num_steps, batch_size, self._num_states]
    flat_means_shape = [
        batch_size, 1, self._num_states,
        tf.reduce_prod(observation_event_shape)
    ]

    flat_probs = tf.reshape(probs, flat_probs_shape)
    # flat_probs :: num_steps batch_size num_states
    flat_means = tf.reshape(means, flat_means_shape)
    # flat_means :: batch_size 1 num_states observation_event_size
    flat_mean = tf.einsum("ijk,jmkl->jiml", flat_probs, flat_means)
    # flat_mean :: batch_size num_steps 1 observation_event_size

    variances = self._observation_distribution.variance()
    variances = tf.broadcast_to(variances, means_shape)
    # variances :: batch_shape num_states observation_event_shape
    flat_variances = tf.reshape(variances, flat_means_shape)
    # flat_variances :: batch_size 1 num_states observation_event_size

    # For a mixture of n distributions with mixture probabilities
    # p[i], and where the individual distributions have means and
    # variances given by mean[i] and var[i], the variance of
    # the mixture is given by:
    #
    #   var = sum i=1..n p[i] * ((mean[i] - mean)**2 + var[i]**2)
    flat_variance = tf.einsum("ijk,jikl->jil", flat_probs,
                              (flat_means - flat_mean)**2 + flat_variances)
    # flat_variance :: batch_size num_steps observation_event_size

    unflat_mean_shape = tf.concat([
        self.batch_shape_tensor(), [self._num_steps], observation_event_shape
    ], axis=0)

    # returns :: batch_shape num_steps observation_event_shape
    return tf.reshape(flat_variance, unflat_mean_shape)
def original_full_attention(query_layer, key_layer, value_layer,
                            attention_mask, size_per_head,
                            attention_probs_dropout_prob):
  """Full quadratic attention calculation.

  Args:
    query_layer: float Tensor of shape [batch_size, num_attention_heads,
      from_seq_length, size_per_head]
    key_layer: float Tensor of shape [batch_size, num_attention_heads,
      to_seq_length, size_per_head]
    value_layer: float Tensor of shape [batch_size, num_attention_heads,
      to_seq_length, size_per_head]
    attention_mask: (optional) int32 Tensor of shape [batch_size,
      from_seq_length, to_seq_length]. The values should be 1 or 0. The
      attention scores will effectively be set to -infinity for any positions
      in the mask that are 0, and will be unchanged for positions that are 1.
    size_per_head: (optional) int. Size of each attention head.
    attention_probs_dropout_prob: (optional) float. Dropout probability of the
      attention probabilities.

  Returns:
    float Tensor of shape [batch_size, from_seq_length, num_attention_heads,
      size_per_head].
  """
  # Directly take n^2 dot product between "query" and "key".
  attention_scores = tf.einsum("BNFH,BNTH->BNFT", query_layer, key_layer)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / np.sqrt(float(size_per_head)))

  if attention_mask is not None:
    # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
    # masked positions, this operation will create a tensor which is 0.0 for
    # positions we want to attend and -10000.0 for masked positions.
    adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

    # Since we are adding it to the raw scores before the softmax, this is
    # effectively the same as removing these entirely.
    attention_scores += adder

  # Normalize the attention scores to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = tf.nn.softmax(attention_scores)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = utils.dropout(attention_probs,
                                  attention_probs_dropout_prob)

  # `context_layer` = [B, F, N, H]
  context_layer = tf.einsum("BNFT,BNTH->BFNH", attention_probs, value_layer)
  return context_layer
def call(self, inputs):
  from_tensor = inputs[0]
  to_tensor = inputs[1]
  attention_mask = inputs[2] if len(inputs) == 3 else None

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`
  # `query_tensor` = [B, F, N, H]
  query_tensor = self._query_dense(from_tensor)

  # `key_tensor` = [B, T, N, H]
  key_tensor = self._key_dense(to_tensor)

  # `value_tensor` = [B, T, N, H]
  value_tensor = self._value_dense(to_tensor)

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  # `attention_scores` = [B, N, F, T]
  if tf.config.experimental.list_physical_devices("GPU"):
    # `query_tensor` = [B, N, F, H]
    query_tensor = tf.transpose(query_tensor, [0, 2, 1, 3])
    # `key_tensor` = [B, N, T, H]
    key_tensor = tf.transpose(key_tensor, [0, 2, 1, 3])
    attention_scores = tf.matmul(query_tensor, key_tensor, transpose_b=True)
  else:
    attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / math.sqrt(float(self._head_size)))

  # Normalize the attention scores to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = self._masked_softmax([attention_scores, attention_mask])

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = self._dropout(attention_probs)

  # `context_layer` = [B, F, N, H]
  return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
def infnet(x):
  # MuE encoder.
  v = mue.encode(x, uln0, rln0, lln, latent_length, latent_alphabet_size,
                 alphabet_size, padded_data_length, transfer_mats, dtype, eps)

  # Construct the approximate posterior using the inference network
  # parameters. Softplus ensures the scale parameter is positive.
  loc = tf.einsum('jk,jkl->l', v, mean_fac)
  scale = tf.nn.softplus(tf.einsum('jk,jkl->l', v, scale_fac))
  if z_distr == 'Normal':
    return Normal(loc, scale, name=name)
  elif z_distr == 'Laplace':
    return Laplace(loc, scale, name=name)
  elif z_distr == 'Exponential':
    return Gamma(tf.nn.softplus(loc), tf.nn.softplus(scale), name=name)
def create_rand_mask_from_inputs(from_blocked_mask, to_blocked_mask, rand_attn,
                                 num_attention_heads, num_rand_blocks,
                                 batch_size, from_seq_length, from_block_size):
  """Creates the random attention mask from blocked input masks.

  Args:
    from_blocked_mask: int32 Tensor of shape [batch_size,
      from_seq_length//from_block_size, from_block_size].
    to_blocked_mask: int32 Tensor of shape [batch_size,
      to_seq_length//to_block_size, to_block_size].
    rand_attn: [batch_size, num_attention_heads,
      from_seq_length//from_block_size-2, num_rand_blocks]
    num_attention_heads: int. Number of attention heads.
    num_rand_blocks: int. Number of random chunks per row.
    batch_size: int. Batch size for computation.
    from_seq_length: int. Length of from sequence.
    from_block_size: int. Size of block in from sequence.

  Returns:
    float Tensor of shape [batch_size, num_attention_heads,
      from_seq_length//from_block_size-2, from_block_size,
      num_rand_blocks*to_block_size].
  """
  num_windows = from_seq_length // from_block_size - 2
  rand_mask = tf.reshape(
      tf.gather(to_blocked_mask, rand_attn, batch_dims=1), [
          batch_size, num_attention_heads, num_windows,
          num_rand_blocks * from_block_size
      ])
  rand_mask = tf.einsum("BLQ,BHLK->BHLQK", from_blocked_mask[:, 1:-1],
                        rand_mask)
  return rand_mask
def testDenseDVIMoments(self):
  """Verifies DenseDVI's moments empirically with samples."""
  tf.random.set_seed(377269)
  batch_size = 3
  num_features = 5
  units = 128
  num_samples = 50000
  inputs = tf.cast(np.random.rand(batch_size, num_features), dtype=tf.float32)
  layer = ed.layers.DenseDVI(units, activation=tf.nn.relu)

  outputs1 = layer(inputs)
  mean1 = outputs1.distribution.mean()
  covariance1 = outputs1.distribution.covariance()

  kernel_samples = layer.kernel.distribution.sample(num_samples)
  outputs2 = layer.activation(
      tf.einsum("bd,sdu->sbu", inputs, kernel_samples) +
      tf.reshape(layer.bias, [1, 1, units]))
  mean2 = tf.reduce_mean(outputs2, axis=0)
  centered_outputs2 = tf.transpose(a=outputs2 - mean2, perm=[1, 2, 0])
  covariance2 = tf.matmul(centered_outputs2, centered_outputs2,
                          transpose_b=True) / float(num_samples)

  # Check % of mismatches is not too high according to heuristic thresholds.
  num_mismatches = np.sum(np.abs(mean1 - mean2) > 5e-3)
  percent_mismatches = num_mismatches / float(batch_size * units)
  self.assertLessEqual(percent_mismatches, 0.05)

  num_mismatches = np.sum(np.abs(covariance1 - covariance2) > 5e-3)
  percent_mismatches = num_mismatches / float(batch_size * units * units)
  self.assertLessEqual(percent_mismatches, 0.05)
def call(self, input_tensor):
  """Forward pass for dense layer with 2D kernel.

  Args:
    input_tensor: Float tensor with rank 3.

  Returns:
    float logits Tensor.
  """
  if self.w is None:
    last_dim = get_shape_list(input_tensor)[-1]
    self.w = tf.compat.v1.get_variable(
        name="kernel",
        shape=[last_dim, self.output_size],
        initializer=self.initializer)
    self.initializer = None
    self._trainable_weights.append(self.w)
  ret = tf.einsum("abc,cd->abd", input_tensor, self.w)

  if self.use_bias:
    if self.b is None:
      self.b = tf.compat.v1.get_variable(
          name="bias",
          shape=[self.output_size],
          initializer=tf.zeros_initializer)
      self._trainable_weights.append(self.b)
    ret += self.b

  if self.activation is not None:
    return self.activation(ret)
  else:
    return ret
def one_hot_minus(inputs, shift):
  """Performs (inputs - shift) % vocab_size in the one-hot space.

  Args:
    inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
      Tensor.
    shift: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
      Tensor specifying how much to shift the corresponding one-hot vector in
      inputs. Soft values perform a "weighted shift": for example,
      shift=[0.2, 0.3, 0.5] performs a linear combination of
      0.2 * shifting by zero; 0.3 * shifting by one; and 0.5 * shifting by two.

  Returns:
    Tensor of same shape and dtype as inputs.
  """
  # TODO(trandustin): Implement with circular conv1d.
  inputs = tf.convert_to_tensor(inputs)
  shift = tf.cast(shift, inputs.dtype)
  vocab_size = inputs.shape[-1]
  if isinstance(vocab_size, tf1.Dimension):
    vocab_size = vocab_size.value
  # Form a [..., vocab_size, vocab_size] matrix. Each batch element of
  # inputs will vector-matrix multiply the vocab_size x vocab_size matrix. This
  # "shifts" the inputs batch element by the corresponding shift batch element.
  shift_matrix = tf.stack(
      [tf.roll(shift, i, axis=-1) for i in range(vocab_size)], axis=-2)
  outputs = tf.einsum('...v,...uv->...u', inputs, shift_matrix)
  return outputs
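# Hypothetical usage sketch (not from the source): with hard one-hot inputs,
# one_hot_minus should reproduce integer subtraction mod vocab_size,
# e.g. (1 - 3) % 5 == 3.
vocab_size = 5
inputs = tf.one_hot([1], depth=vocab_size)  # represents the integer 1
shift = tf.one_hot([3], depth=vocab_size)   # shift (subtract) by 3
outputs = one_hot_minus(inputs, shift)
print(tf.argmax(outputs, axis=-1).numpy())  # [3], since (1 - 3) % 5 == 3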
def call(self, inputs):
  ret = tf.einsum(self.equation, inputs, self.kernel)
  if self.bias is not None:
    ret += self.bias
  if self.activation is not None:
    ret = self.activation(ret)
  return ret
def call(self, inputs):
  ret = tf.einsum(self._einsum_string, inputs, self._kernel)
  if self._use_bias:
    ret += self._bias
  if self._activation is not None:
    ret = self._activation(ret)
  return ret
def test_batching(self, input_batch_shape, kernel_batch_shape):
  input_shape = (12, 12, 2)
  filter_shape = (3, 3)
  channels_out = 4
  strides = 2
  dilations = (1, 1)
  padding = 'SAME'

  x, k = _make_input_and_kernel(
      self.make_input,
      input_batch_shape=input_batch_shape,
      input_shape=input_shape,
      kernel_batch_shape=kernel_batch_shape,
      filter_shape=filter_shape,
      channels_out=channels_out,
      dtype=self.dtype)

  conv_fn = self.make_conv_fn(filter_shape, strides, padding, dilations)
  y_batched = conv_fn(x, k)

  broadcast_batch_shape = ps.broadcast_shape(input_batch_shape,
                                             kernel_batch_shape)
  broadcasted_input = tf.broadcast_to(
      x, shape=ps.concat([broadcast_batch_shape, input_shape], axis=0))
  broadcasted_kernel = tf.broadcast_to(
      k, shape=ps.concat([broadcast_batch_shape, ps.shape(k)[-2:]], axis=0))

  flat_y = tf.reshape(
      y_batched,
      shape=ps.pad(
          ps.shape(y_batched)[-3:], paddings=[[1, 0]], constant_values=-1))
  flat_x = tf.reshape(
      broadcasted_input,
      shape=ps.pad(input_shape, paddings=[[1, 0]], constant_values=-1))
  flat_tf_kernel = tf.einsum(
      '...ij->...ji',
      tf.reshape(
          broadcasted_kernel,
          shape=ps.concat(
              [(-1,), filter_shape, (input_shape[-1], channels_out)], axis=0)))

  rank = 2
  output_shape, strides_ = convolution_util._get_output_shape(
      rank=rank,
      strides=(strides,) * rank,
      padding=padding,
      dilations=dilations,
      input_shape=input_shape,
      output_size=channels_out,
      filter_shape=filter_shape)

  y_expected = tf.vectorized_map(
      lambda args: tf.nn.conv2d_transpose(  # pylint: disable=g-long-lambda
          args[0][tf.newaxis],
          args[1],
          output_shape=ps.concat([[1], output_shape], axis=0),
          strides=strides_,
          padding=padding),
      elems=(flat_x, flat_tf_kernel))

  [y_actual_, y_expected_] = self.evaluate(
      [flat_y, tf.squeeze(y_expected, axis=1)])
  self.assertAllClose(y_expected_, y_actual_, rtol=1e-5, atol=0)
def laplace_attention(q, k, v, scale, normalise):
  """Computes laplace exponential attention.

  Args:
    q: queries. Tensor of shape [batch_size, m, d_k].
    k: keys. Tensor of shape [batch_size, n, d_k].
    v: values. Tensor of shape [batch_size, n, d_v].
    scale: float that scales the L1 distance.
    normalise: Boolean that determines whether weights sum to 1.

  Returns:
    Tensor of shape [batch_size, m, d_v].
  """
  k = tf.expand_dims(k, axis=1)  # [batch_size, 1, n, d_k]
  q = tf.expand_dims(q, axis=2)  # [batch_size, m, 1, d_k]
  unnorm_weights = -tf.abs((k - q) / scale)  # [batch_size, m, n, d_k]
  unnorm_weights = tf.reduce_sum(unnorm_weights, axis=-1)  # [batch_size, m, n]
  if normalise:
    weight_fn = tf.nn.softmax
  else:
    weight_fn = lambda x: 1 + tf.tanh(x)
  weights = weight_fn(unnorm_weights)  # [batch_size, m, n]
  rep = tf.einsum('bik,bkj->bij', weights, v)  # [batch_size, m, d_v]
  return rep
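# Hypothetical usage sketch (not from the source): attend over n=4 key/value
# pairs with m=2 queries per batch element; shapes follow the docstring above.
q = tf.random.normal([3, 2, 8])    # [batch_size, m, d_k]
k = tf.random.normal([3, 4, 8])    # [batch_size, n, d_k]
v = tf.random.normal([3, 4, 16])   # [batch_size, n, d_v]
rep = laplace_attention(q, k, v, scale=1.0, normalise=True)
print(rep.shape)  # (3, 2, 16)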
def _batch_outer_product(target, event_ndims):
  """Calculates the batch outer product along `target`'s event dimensions.

  More precisely, a `tf.einsum` operation is used to calculate desired
  pairwise products as follows:

  For `event_ndims=0`, the return value is:
    `tf.einsum("...,...->...", target, target)`
  For `event_ndims=1`:
    `tf.einsum("...a,...b->...ab", target, target)`
  For `event_ndims=2`:
    `tf.einsum("...ab,...cd->...abcd", target, target)`
  ...

  Args:
    target: Target `Tensor` for the `tf.einsum` computation.
    event_ndims: Both the number of dimensions that specify the event shape
      and the desired number of dimensions for cross product terms.

  Returns:
    outer_product: A `Tensor` with shape B + E + E for all pairwise products
      of `target` in the event dimensions.
  """
  assign_indices = ''.join(
      list(map(chr, range(ord('a'), ord('a') + event_ndims * 2))))
  first_indices = assign_indices[:event_ndims]
  second_indices = assign_indices[event_ndims:]
  einsum_formula = '...{},...{}->...{}'.format(first_indices, second_indices,
                                               assign_indices)
  return tf.einsum(einsum_formula, target, target)
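# Worked example (derived, not from the source): for event_ndims=1 the helper
# builds the formula '...a,...b->...ab', i.e. a batched outer product.
target = tf.constant([[1., 2.], [3., 4.]])           # batch of two length-2 events
outer = _batch_outer_product(target, event_ndims=1)  # shape [2, 2, 2]
# outer[0] == [[1., 2.], [2., 4.]]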
def _compute_attention(self, query, key, value, attention_mask=None,
                       training=None):
  """Applies Dot-product attention with query, key, value tensors.

  This function defines the computation inside `call` with projected
  multi-head Q, K, V inputs. Users can override this function for
  customized attention implementation.

  Args:
    query: Projected query `Tensor` of shape `(B, T, N, key_dim)`.
    key: Projected key `Tensor` of shape `(B, S, N, key_dim)`.
    value: Projected value `Tensor` of shape `(B, S, N, value_dim)`.
    attention_mask: a boolean mask of shape `(B, T, S)`, that prevents
      attention to certain positions. It is generally not needed if the
      `query` and `value` (and/or `key`) are masked.
    training: Python boolean indicating whether the layer should behave in
      training mode (adding dropout) or in inference mode (doing nothing).

  Returns:
    attention_output: Multi-headed outputs of attention computation.
    attention_scores: Multi-headed attention weights.
  """
  # Note: Applying scalar multiply at the smaller end of einsum improves
  # XLA performance, but may introduce slight numeric differences in
  # the Transformer attention head.
  query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  attention_scores = tf.einsum(self._dot_product_equation, key, query)

  attention_scores = self._masked_softmax(attention_scores, attention_mask)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_scores_dropout = self._dropout_layer(
      attention_scores, training=training)

  # `context_layer` = [B, T, N, H]
  attention_output = tf.einsum(
      self._combine_equation, attention_scores_dropout, value)
  return attention_output, attention_scores
def log_likelihood_fn(weights, features, labels, reduce_sum=True):
  """The log_likelihood function."""
  logits = tf.einsum('nd,...d->...n', features, weights)
  log_likelihood = tfd.Bernoulli(logits=logits).log_prob(labels)
  if reduce_sum:
    return tf.reduce_sum(log_likelihood, [-1])
  else:
    return log_likelihood
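# Hypothetical usage sketch (not from the source): evaluate the Bernoulli
# log-likelihood of a single weight vector on toy data. Assumes `tfd` is
# tensorflow_probability.distributions.
features = tf.constant([[1., 0.], [0., 1.]])       # n=2 examples, d=2 features
labels = tf.constant([1, 0])
weights = tf.constant([0.5, -0.5])                 # shape [d]
ll = log_likelihood_fn(weights, features, labels)  # scalar sum over examples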
def call(self, input_tensor):
  """Forward pass for dense layer with 3D kernel.

  Args:
    input_tensor: float Tensor of shape [batch, seq_length, hidden_size].

  Returns:
    float logits Tensor.
  """
  last_dim = get_shape_list(input_tensor)[-1]
  if self.w is None:
    self.w = tf.compat.v1.get_variable(
        name="kernel",
        shape=[last_dim, self.num_attention_heads * self.size_per_head],
        initializer=self.initializer)
    self.initializer = None
    self._trainable_weights.append(self.w)
  reshape_w = tf.reshape(
      self.w, [last_dim, self.num_attention_heads, self.size_per_head])
  if self.head_first:
    ret = tf.einsum("abc,cde->adbe", input_tensor, reshape_w)
  else:
    ret = tf.einsum("abc,cde->abde", input_tensor, reshape_w)

  if self.use_bias:
    if self.b is None:
      self.b = tf.compat.v1.get_variable(
          name="bias",
          shape=[self.num_attention_heads * self.size_per_head],
          initializer=tf.zeros_initializer)
      self._trainable_weights.append(self.b)
    if self.head_first:
      reshape_b = tf.reshape(
          self.b, [1, self.num_attention_heads, 1, self.size_per_head])
    else:
      reshape_b = tf.reshape(
          self.b, [self.num_attention_heads, self.size_per_head])
    ret += reshape_b

  if self.activation is not None:
    return self.activation(ret)
  else:
    return ret
def model_fn(features):
  unscaled_weights = yield root(tfd.Independent(tfd.Normal(zero, one), 1))
  local_scales = yield root(tfd.Independent(tfd.Gamma(half, half), 1))
  global_scale = yield root(tfd.Gamma(0.5, 0.5))
  weights = unscaled_weights * local_scales * global_scale[Ellipsis, tf.newaxis]
  logits = tf.einsum('nd,...d->...n', features, weights)
  yield tfd.Independent(tfd.Bernoulli(logits=logits), 1)
def model_coroutine():
  beta = yield root(tfd.Sample(tfd.Normal(0, 1), [p], name='beta'))
  alpha = yield root(tfd.Normal(0, 1, name='alpha'))
  kappa = yield root(tfd.Gamma(1, 1, name='kappa'))
  mu = tf.math.sigmoid(alpha[..., tf.newaxis] +
                       tf.einsum('...p,np->...n', beta, x))
  yield tfd.Independent(beta_proportion(mu, kappa[..., tf.newaxis]),
                        reinterpreted_batch_ndims=1,
                        name='prob')
def call(self, inputs):
  reshaped_inputs = tf.reshape(inputs, [tf.shape(inputs)[0], -1])
  reshaped_inputs = (
      reshaped_inputs - tf.reduce_mean(reshaped_inputs, 0, keepdims=True))
  s = power_iterate(reshaped_inputs, self.u_var)
  variance_explained_ratio = s / tf.einsum('nc,nc->', reshaped_inputs,
                                           reshaped_inputs)
  self.add_loss(self.reg_strength *
                tf.nn.relu(variance_explained_ratio - self.threshold))
  return inputs  # Pass-through layer.
def call(self, inputs, decode_loop_step=None):
  from_tensor = inputs[0]
  to_tensor = inputs[1]
  attention_mask = inputs[2] if len(inputs) >= 3 else None
  cache = inputs[3] if len(inputs) >= 4 else None

  # Scalar dimensions referenced here:
  #   B = batch size (number of sequences)
  #   F = `from_tensor` sequence length
  #   T = `to_tensor` sequence length
  #   N = `num_attention_heads`
  #   H = `size_per_head`
  # `query_tensor` = [B, F, N, H]
  query_tensor = self._query_dense(from_tensor)

  # `key_tensor` = [B, T, N, H]
  key_tensor = self._key_dense(to_tensor)

  # `value_tensor` = [B, T, N, H]
  value_tensor = self._value_dense(to_tensor)

  if cache:
    key_tensor, value_tensor = self._update_cache(key_tensor, value_tensor,
                                                  cache, decode_loop_step)

  # Take the dot product between "query" and "key" to get the raw
  # attention scores.
  attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
  attention_scores = tf.multiply(attention_scores,
                                 1.0 / math.sqrt(float(self._head_size)))

  # Normalize the attention scores to probabilities.
  # `attention_probs` = [B, N, F, T]
  attention_probs = self._masked_softmax([attention_scores, attention_mask])

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  attention_probs = self._dropout(attention_probs)

  # `context_layer` = [B, F, N, H]
  return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor), cache
def call(self, inputs):
  x = inputs
  B, H, W, C = x.shape
  h = self.normalize(x)
  q = self.nin_q(h)
  k = self.nin_k(h)
  v = self.nin_v(h)

  w = tf.einsum('bhwc,bHWc->bhwHW', q, k) * (int(C)**(-0.5))
  w = tf.reshape(w, [B, H, W, H * W])
  w = tf.nn.softmax(w, -1)
  w = tf.reshape(w, [B, H, W, H, W])

  h = tf.einsum('bhwHW,bHWc->bhwc', w, v)
  h = self.nin_proj_out(h)

  assert h.shape == x.shape
  return x + h
def body_fn(vecs, i):
  # Slice out the vector w.r.t. which we're orthogonalizing the rest.
  u = tf.math.l2_normalize(vecs[..., i, tf.newaxis], axis=-2)
  # Find weights by dotting the d x 1 against the d x n.
  weights = tf.einsum('...dm,...dn->...n', u, vecs)
  # Project out vector `u` from the trailing vectors.
  masked_weights = tf.where(tf.range(n) > i, weights, 0.)[..., tf.newaxis, :]
  vecs = vecs - tf.math.multiply_no_nan(u, masked_weights)
  tensorshape_util.set_shape(vecs, vectors.shape)
  return vecs, i + 1
def _weighted_sum(self, alphas, v, contract_dim_a=-3, contract_dim_v=-3):
  num_batch_axes = len(alphas.shape) + contract_dim_a
  pre_str = 'abcdefghij'[:num_batch_axes]
  in_dim_a = -contract_dim_a - 2
  in_dim_v = -contract_dim_v - 2
  in_str_a = 'zyxwv'[:in_dim_a]
  in_str_v = 'zyxwv'[:in_dim_v]
  einsum_str = '{}Q{}M,{}M{}C->{}Q{}C'.format(pre_str, in_str_a, pre_str,
                                              in_str_v, pre_str, in_str_a)
  return tf.einsum(einsum_str, alphas, v)
def grad(dy):
  """Compute the gradient of the expectation via integration by parts."""
  output, noise_grad = perturbed_output, noise_gradient
  # Adds dummy feature/channel dimension internally.
  if perturbed_input.shape.rank > output.shape.rank:
    dy = tf.expand_dims(dy, axis=-1)
    output = tf.expand_dims(output, axis=-1)
  # Adds dummy batch dimension internally.
  if not batched:
    dy = tf.expand_dims(dy, axis=0)
  # Flattens [D1, ..., Dk] to a single feat dim [D].
  flatten = lambda t: tf.reshape(t, (tf.shape(t)[0], tf.shape(t)[1], -1))
  dy = tf.reshape(dy, (tf.shape(dy)[0], -1))  # (B, D)
  output = flatten(output)  # (N, B, D)
  noise_grad = flatten(noise_grad)  # (N, B, D)

  g = tf.einsum('nbd,nb->bd', noise_grad,
                tf.einsum('nbd,bd->nb', output, dy))
  g /= sigma * num_samples
  return tf.reshape(g, original_input_shape)
def log_likelihood_fn(weights, features, labels, reduce_sum=True):
  """The log_likelihood function."""
  features = tf.convert_to_tensor(features, tf.float32)
  features = _add_bias(features)
  labels = tf.convert_to_tensor(labels)
  logits = tf.einsum('nd,...d->...n', features, weights)
  log_likelihood = tfd.Bernoulli(logits=logits).log_prob(labels)
  if reduce_sum:
    return tf.reduce_sum(log_likelihood, [-1])
  else:
    return log_likelihood
def _dot_product(self, q, k, contract_dim_q=-3, contract_dim_k=-3):
  num_batch_axes = len(q.shape) + contract_dim_q
  pre_str = 'abcdefghij'[:num_batch_axes]
  in_dim_q = -contract_dim_q - 2
  in_dim_k = -contract_dim_k - 2
  in_str_q = 'zyxwv'[:in_dim_q]
  in_str_k = 'zyxwv'[:in_dim_k]
  einsum_str = '{}Q{}C,{}M{}C->{}Q{}M'.format(pre_str, in_str_q, pre_str,
                                              in_str_k, pre_str, in_str_q)
  return tf.einsum(einsum_str, q, k)
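# Worked example (derived, not from the source): with the default
# contract_dim_q=contract_dim_k=-3 and rank-4 inputs, _dot_product builds
# 'aQzC,aMzC->aQzM', and _weighted_sum above builds 'aQzM,aMzC->aQzC';
# together they contract queries against keys and then re-weight the values.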