def _maybe_validate_target_accept_prob(target_accept_prob, validate_args):
  """Validates that target_accept_prob is in (0, 1)."""
  if not validate_args:
    return target_accept_prob
  with tf.control_dependencies([
      tf1.assert_positive(
          target_accept_prob, message='`target_accept_prob` must be > 0.'),
      tf1.assert_less(
          target_accept_prob,
          tf.ones_like(target_accept_prob),
          message='`target_accept_prob` must be < 1.'),
  ]):
    return tf.identity(target_accept_prob)
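# A minimal usage sketch (not part of the original file). It assumes `tf` is
# the TF2 API and `tf1` is `tensorflow.compat.v1`, which is what the mix of
# `tf1.assert_*` and `tf.identity` above suggests:
#
#   target = tf.constant(0.75)
#   checked = _maybe_validate_target_accept_prob(target, validate_args=True)
#   # `checked` evaluates to 0.75; a value such as 1.5 or -0.1 would instead
#   # raise tf.errors.InvalidArgumentError when the assertions execute.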
def __init__(self,
             temperature,
             logits=None,
             probs=None,
             validate_args=False,
             allow_nan_stats=True,
             name="GSTBernoulli",
             dtype=tf.int32):
  """Construct GSTBernoulli distributions.

  Args:
    temperature: A 0-D `Tensor`, representing the temperature of a set of
      GSTBernoulli distributions. The temperature should be positive.
    logits: An N-D `Tensor` representing the log-odds of a positive event.
      Each entry in the `Tensor` parametrizes an independent GSTBernoulli
      distribution where the probability of an event is sigmoid(logits).
      Only one of `logits` or `probs` should be passed in.
    probs: An N-D `Tensor` representing the probability of a positive event.
      Each entry in the `Tensor` parameterizes an independent Bernoulli
      distribution. Only one of `logits` or `probs` should be passed in.
    validate_args: Python `bool`, default `False`. When `True`, distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False`, invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
    dtype: Type of the Tensors.

  Raises:
    ValueError: If both `probs` and `logits` are passed, or if neither.
  """
  with tf.name_scope(name, values=[logits, probs, temperature]) as name:
    self._temperature = tf.convert_to_tensor(
        temperature, name="temperature", dtype=dtype)
    if validate_args:
      with tf.control_dependencies([tf.assert_positive(temperature)]):
        self._temperature = tf.identity(self._temperature)
    super(GSTBernoulli, self).__init__(
        logits=logits,
        probs=probs,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        dtype=dtype,
        name=name)
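# A hedged usage sketch (illustrative values, not part of the original file).
# GSTBernoulli is assumed to subclass a TF1-style Bernoulli distribution, as
# the `tf.name_scope(name, values=...)` signature above suggests:
#
#   dist = GSTBernoulli(temperature=0.5, logits=[-1.0, 0.0, 1.0])
#   samples = dist.sample(10)  # int32 Tensor of shape [10, 3]
#
# Passing both `logits` and `probs`, or neither, raises ValueError in the
# parent constructor.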
def ptb_producer(raw_data, batch_size, num_steps, name=None):
  """Iterate on the raw PTB data.

  This chunks up raw_data into batches of examples and returns Tensors that
  are drawn from these batches.

  Args:
    raw_data: one of the raw data outputs from ptb_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.
    name: the name of this operation (optional).

  Returns:
    A pair of Tensors, each shaped [batch_size, num_steps]. The second
    element of the tuple is the same data time-shifted to the right by one.

  Raises:
    tf.errors.InvalidArgumentError: if batch_size or num_steps are too high.
  """
  with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
    raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)

    data_len = tf.size(raw_data)
    batch_len = data_len // batch_size
    data = tf.reshape(raw_data[0:batch_size * batch_len],
                      [batch_size, batch_len])

    epoch_size = (batch_len - 1) // num_steps
    assertion = tf.assert_positive(
        epoch_size,
        message="epoch_size == 0, decrease batch_size or num_steps")
    with tf.control_dependencies([assertion]):
      epoch_size = tf.identity(epoch_size, name="epoch_size")

    i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
    x = tf.strided_slice(data, [0, i * num_steps],
                         [batch_size, (i + 1) * num_steps])
    x.set_shape([batch_size, num_steps])
    y = tf.strided_slice(data, [0, i * num_steps + 1],
                         [batch_size, (i + 1) * num_steps + 1])
    y.set_shape([batch_size, num_steps])
    return x, y
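# A minimal usage sketch (hypothetical word ids; not part of the original
# file). `range_input_producer` is queue-based, so TF1 queue runners must be
# started before the Tensors yield values:
#
#   raw_data = list(range(1000))  # e.g. the train ids from ptb_raw_data
#   x, y = ptb_producer(raw_data, batch_size=4, num_steps=5)
#   with tf.Session() as sess:
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess, coord=coord)
#     x_val, y_val = sess.run([x, y])  # each [4, 5]; y is x shifted by one
#     coord.request_stop()
#     coord.join(threads)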
def _build(self, memory, query, memory_mask=None):
  """Perform a differentiable read.

  Args:
    memory: [batch_size, memory_size, memory_word_size]-shaped Tensor of
      dtype float32. This represents, for each example and memory slot, a
      single embedding to attend over.
    query: [batch_size, query_word_size]-shaped Tensor of dtype float32.
      Represents, for each example, a single embedding representing a query.
    memory_mask: None or [batch_size, memory_size]-shaped Tensor of dtype
      bool. An entry of False indicates that a memory slot should not enter
      the resulting weighted sum. If None, all memory is used.

  Returns:
    An AttentionOutput instance containing:
      read: [batch_size, memory_word_size]-shaped Tensor of dtype float32.
        This represents, for each example, a weighted sum of the contents of
        the memory.
      weights: [batch_size, memory_size]-shaped Tensor of dtype float32.
        This represents, for each example and memory slot, the attention
        weights used to compute the read.
      weight_logits: [batch_size, memory_size]-shaped Tensor of dtype
        float32. This represents, for each example and memory slot, the
        logits of the attention weights, that is, `weights` is calculated by
        taking the softmax of the weight logits.

  Raises:
    UnderspecifiedError: if memory_word_size or query_word_size can not be
      inferred.
    IncompatibleShapeError: if memory, query, memory_mask, or output of
      attention_logit_mod do not match expected shapes.
  """
  if len(memory.get_shape()) != 3:
    raise base.IncompatibleShapeError(
        "memory must have shape [batch_size, memory_size, memory_word_size].")
  if len(query.get_shape()) != 2:
    raise base.IncompatibleShapeError(
        "query must have shape [batch_size, query_word_size].")
  if memory_mask is not None and len(memory_mask.get_shape()) != 2:
    raise base.IncompatibleShapeError(
        "memory_mask must have shape [batch_size, memory_size].")

  # Ensure final dimensions are defined, else the attention logit module will
  # be unable to infer input size when constructing variables.
  inferred_memory_word_size = memory.get_shape()[2].value
  inferred_query_word_size = query.get_shape()[1].value
  if inferred_memory_word_size is None or inferred_query_word_size is None:
    raise base.UnderspecifiedError(
        "memory_word_size and query_word_size must be known at graph "
        "construction time.")

  memory_shape = tf.shape(memory)
  batch_size = memory_shape[0]
  memory_size = memory_shape[1]

  query_shape = tf.shape(query)
  query_batch_size = query_shape[0]

  # Transform query to have same number of words as memory.
  #
  # expanded_query: [batch_size, memory_size, query_word_size].
  expanded_query = tf.tile(tf.expand_dims(query, axis=1), [1, memory_size, 1])

  # Compute attention weights for each memory slot.
  #
  # attention_weight_logits: [batch_size, memory_size]
  with tf.control_dependencies(
      [tf.assert_equal(batch_size, query_batch_size)]):
    concatenated_embeddings = tf.concat(
        values=[memory, expanded_query], axis=2)

  batch_apply_attention_logit = basic.BatchApply(
      self._attention_logit_mod,
      n_dims=2,
      name="batch_apply_attention_logit")
  attention_weight_logits = batch_apply_attention_logit(
      concatenated_embeddings)

  # Note: basic.BatchApply() will automatically reshape the [batch_size *
  # memory_size, 1]-shaped result of self._attention_logit_mod(...) into a
  # [batch_size, memory_size, 1]-shaped Tensor. If
  # self._attention_logit_mod(...) returns something with more dimensions,
  # then attention_weight_logits will have extra dimensions, too.
  if len(attention_weight_logits.get_shape()) != 3:
    raise base.IncompatibleShapeError(
        "attention_weight_logits must be a rank-3 Tensor. Are you sure that "
        "attention_logit_mod() returned a [batch_size * memory_size, "
        "1]-shaped Tensor?")

  # Remove final length-1 dimension.
  attention_weight_logits = tf.squeeze(attention_weight_logits, [2])

  # Mask out ignored memory slots by assigning them very small logits. Also
  # assert that every example has at least one valid memory slot, else we'd
  # end up averaging all memory slots equally.
  if memory_mask is not None:
    num_remaining_memory_slots = tf.reduce_sum(
        tf.cast(memory_mask, dtype=tf.int32), axis=[1])
    with tf.control_dependencies(
        [tf.assert_positive(num_remaining_memory_slots)]):
      finfo = np.finfo(np.float32)
      kept_indices = tf.cast(memory_mask, dtype=tf.float32)
      ignored_indices = tf.cast(tf.logical_not(memory_mask), dtype=tf.float32)
      lower_bound = finfo.max * kept_indices + finfo.min * ignored_indices
    attention_weight_logits = tf.minimum(attention_weight_logits, lower_bound)

  # attended_memory: [batch_size, memory_word_size].
  attention_weight = tf.reshape(
      tf.nn.softmax(attention_weight_logits),
      shape=[batch_size, memory_size, 1])
  # The multiplication is elementwise and relies on broadcasting the weights
  # across memory_word_size. Then we sum across the memory slots.
  attended_memory = tf.reduce_sum(memory * attention_weight, axis=[1])

  # Infer shape of result as much as possible.
  inferred_batch_size, _, inferred_memory_word_size = (
      memory.get_shape().as_list())
  attended_memory.set_shape([inferred_batch_size, inferred_memory_word_size])

  return AttentionOutput(
      read=attended_memory,
      weights=tf.squeeze(attention_weight, [2]),
      weight_logits=attention_weight_logits)
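# A hedged usage sketch (not part of the original file). This assumes the
# enclosing module is Sonnet's AttentiveRead, constructed with a logit module
# such as `basic.Linear(output_size=1)`; shapes below are illustrative:
#
#   attend = AttentiveRead(
#       attention_logit_mod=basic.Linear(output_size=1), name="attend")
#   memory = tf.placeholder(tf.float32, [None, 10, 32])  # [B, slots, words]
#   query = tf.placeholder(tf.float32, [None, 16])
#   output = attend(memory, query)
#   # output.read:          [B, 32]
#   # output.weights:       [B, 10]
#   # output.weight_logits: [B, 10]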