Example #1
import tensorflow.compat.v1 as tf1  # assertion ops live in the v1 namespace
import tensorflow.compat.v2 as tf


def _maybe_validate_target_accept_prob(target_accept_prob, validate_args):
    """Validates that target_accept_prob is in (0, 1)."""
    if not validate_args:
        return target_accept_prob
    with tf.control_dependencies([
            tf1.assert_positive(target_accept_prob,
                                message='`target_accept_prob` must be > 0.'),
            tf1.assert_less(target_accept_prob,
                            tf.ones_like(target_accept_prob),
                            message='`target_accept_prob` must be < 1.')
    ]):
        return tf.identity(target_accept_prob)
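
A brief usage sketch of the validate-then-identity pattern above; the 0.75 value is illustrative:

# Uses the `tf`/`tf1` aliases imported above.
prob = tf.constant(0.75)
checked = _maybe_validate_target_accept_prob(prob, validate_args=True)
# `checked` is `prob` re-emitted through tf.identity, gated on both
# assertions; with validate_args=False the tensor is returned as-is.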
Example #2
    def __init__(self,
                 temperature,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="GSTBernoulli",
                 dtype=tf.int32):
        """Construct GSTBernoulli distributions.

    Args:
      temperature: A 0-D `Tensor`, representing the temperature of a set of
        GSTBernoulli distributions. The temperature should be positive.
      logits: An N-D `Tensor` representing the log-odds of a positive event.
        Each entry in the `Tensor` parameterizes an independent GSTBernoulli
        distribution where the probability of an event is sigmoid(logits). Only
        one of `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a positive event.
        Each entry in the `Tensor` parameterizes an independent Bernoulli
        distribution. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or more
        of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
      dtype: Type of the Tensors.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
        with tf.name_scope(name, values=[logits, probs, temperature]) as name:
            self._temperature = tf.convert_to_tensor(temperature,
                                                     name="temperature",
                                                     dtype=dtype)
            if validate_args:
                with tf.control_dependencies(
                        [tf.assert_positive(temperature)]):
                    self._temperature = tf.identity(self._temperature)
            super(GSTBernoulli, self).__init__(logits=logits,
                                               probs=probs,
                                               validate_args=validate_args,
                                               allow_nan_stats=allow_nan_stats,
                                               dtype=dtype,
                                               name=name)
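
A hedged construction sketch; it assumes the full class (and its Bernoulli-style base, which enforces the logits-xor-probs rule and provides `sample()`) is in scope:

import tensorflow as tf  # TF1-style API, matching the constructor above

dist = GSTBernoulli(temperature=0.5, logits=tf.constant([-1.0, 0.0, 2.0]))
sample = dist.sample()  # one int32 draw per logit entry
# Passing both `logits` and `probs` (or neither) raises ValueError in
# the base-class constructor.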
Example #3
import tensorflow as tf  # TF1-era API (queue runners, tf.assert_positive)


def ptb_producer(raw_data, batch_size, num_steps, name=None):
    """Iterate on the raw PTB data.

    This chunks up raw_data into batches of examples and returns Tensors that
    are drawn from these batches.

    Args:
      raw_data: one of the raw data outputs from ptb_raw_data.
      batch_size: int, the batch size.
      num_steps: int, the number of unrolls.
      name: the name of this operation (optional).

    Returns:
      A pair of Tensors, each shaped [batch_size, num_steps]. The second element
      of the tuple is the same data time-shifted to the right by one.

    Raises:
      tf.errors.InvalidArgumentError: if batch_size or num_steps are too high.
    """
    with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
        raw_data = tf.convert_to_tensor(raw_data,
                                        name="raw_data",
                                        dtype=tf.int32)

        data_len = tf.size(raw_data)
        batch_len = data_len // batch_size
        data = tf.reshape(raw_data[0:batch_size * batch_len],
                          [batch_size, batch_len])

        epoch_size = (batch_len - 1) // num_steps
        assertion = tf.assert_positive(
            epoch_size,
            message="epoch_size == 0, decrease batch_size or num_steps")
        with tf.control_dependencies([assertion]):
            epoch_size = tf.identity(epoch_size, name="epoch_size")

        i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
        x = tf.strided_slice(data, [0, i * num_steps],
                             [batch_size, (i + 1) * num_steps])
        x.set_shape([batch_size, num_steps])
        y = tf.strided_slice(data, [0, i * num_steps + 1],
                             [batch_size, (i + 1) * num_steps + 1])
        y.set_shape([batch_size, num_steps])
        return x, y
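
For concreteness, a worked shape example (numbers are illustrative): with len(raw_data) = 100, batch_size = 4, and num_steps = 6, batch_len = 100 // 4 = 25, data is reshaped to [4, 25], and epoch_size = (25 - 1) // 6 = 4, so four [4, 6] (x, y) pairs are produced per epoch. A minimal TF1-style driver, assuming graph mode with queue runners:

raw_data = list(range(100))  # toy corpus of token ids
x, y = ptb_producer(raw_data, batch_size=4, num_steps=6)
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print(sess.run([x, y]))  # two [4, 6] int32 batches; y is x one step ahead
    coord.request_stop()
    coord.join(threads)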
Example #4
    def _build(self, memory, query, memory_mask=None):
        """Perform a differentiable read.

    Args:
      memory: [batch_size, memory_size, memory_word_size]-shaped Tensor of
        dtype float32. This represents, for each example and memory slot, a
        single embedding to attend over.
      query: [batch_size, query_word_size]-shaped Tensor of dtype float32.
        Represents, for each example, a single embedding representing a query.
      memory_mask: None or [batch_size, memory_size]-shaped Tensor of dtype
        bool. An entry of False indicates that a memory slot should not enter
        the resulting weighted sum. If None, all memory is used.

    Returns:
      An AttentionOutput instance containing:
        read: [batch_size, memory_word_size]-shaped Tensor of dtype float32.
          This represents, for each example, a weighted sum of the contents of
          the memory.
        weights: [batch_size, memory_size]-shaped Tensor of dtype float32. This
          represents, for each example and memory slot, the attention weights
          used to compute the read.
        weight_logits: [batch_size, memory_size]-shaped Tensor of dtype float32.
          This represents, for each example and memory slot, the logits of the
          attention weights, that is, `weights` is calculated by taking the
          softmax of the weight logits.

    Raises:
      UnderspecifiedError: if memory_word_size or query_word_size cannot be
        inferred.
      IncompatibleShapeError: if memory, query, memory_mask, or output of
        attention_logit_mod do not match expected shapes.
    """
        if len(memory.get_shape()) != 3:
            raise base.IncompatibleShapeError(
                "memory must have shape [batch_size, memory_size, memory_word_size]."
            )

        if len(query.get_shape()) != 2:
            raise base.IncompatibleShapeError(
                "query must have shape [batch_size, query_word_size].")

        if memory_mask is not None and len(memory_mask.get_shape()) != 2:
            raise base.IncompatibleShapeError(
                "memory_mask must have shape [batch_size, memory_size].")

        # Ensure final dimensions are defined, else the attention logit module will
        # be unable to infer input size when constructing variables.
        inferred_memory_word_size = memory.get_shape()[2].value
        inferred_query_word_size = query.get_shape()[1].value
        if inferred_memory_word_size is None or inferred_query_word_size is None:
            raise base.UnderspecifiedError(
                "memory_word_size and query_word_size must be known at graph "
                "construction time.")

        memory_shape = tf.shape(memory)
        batch_size = memory_shape[0]
        memory_size = memory_shape[1]

        query_shape = tf.shape(query)
        query_batch_size = query_shape[0]

        # Transform query to have same number of words as memory.
        #
        # expanded_query: [batch_size, memory_size, query_word_size].
        expanded_query = tf.tile(tf.expand_dims(query, axis=1),
                                 [1, memory_size, 1])

        # Compute attention weights for each memory slot.
        #
        # attention_weight_logits: [batch_size, memory_size]
        with tf.control_dependencies(
            [tf.assert_equal(batch_size, query_batch_size)]):
            concatenated_embeddings = tf.concat(
                values=[memory, expanded_query], axis=2)

        batch_apply_attention_logit = basic.BatchApply(
            self._attention_logit_mod,
            n_dims=2,
            name="batch_apply_attention_logit")
        attention_weight_logits = batch_apply_attention_logit(
            concatenated_embeddings)

        # Note: basic.BatchApply() will automatically reshape the [batch_size *
        # memory_size, 1]-shaped result of self._attention_logit_mod(...) into a
        # [batch_size, memory_size, 1]-shaped Tensor. If
        # self._attention_logit_mod(...) returns something with more dimensions,
        # then attention_weight_logits will have extra dimensions, too.
        if len(attention_weight_logits.get_shape()) != 3:
            raise base.IncompatibleShapeError(
                "attention_weight_logits must be a rank-3 Tensor. Are you sure that "
                "attention_logit_mod() returned a [batch_size * memory_size, "
                "1]-shaped Tensor?")

        # Remove final length-1 dimension.
        attention_weight_logits = tf.squeeze(attention_weight_logits, [2])

        # Mask out ignored memory slots by assigning them very small logits.
        # The assertion below ensures every example keeps at least one valid
        # memory slot; if none remained, the softmax would simply average all
        # memory slots equally.
        if memory_mask is not None:
            num_remaining_memory_slots = tf.reduce_sum(tf.cast(memory_mask,
                                                               dtype=tf.int32),
                                                       axis=[1])
            with tf.control_dependencies(
                [tf.assert_positive(num_remaining_memory_slots)]):
                finfo = np.finfo(np.float32)
                kept_indices = tf.cast(memory_mask, dtype=tf.float32)
                ignored_indices = tf.cast(tf.logical_not(memory_mask),
                                          dtype=tf.float32)
                lower_bound = finfo.max * kept_indices + finfo.min * ignored_indices
                attention_weight_logits = tf.minimum(attention_weight_logits,
                                                     lower_bound)

        # attended_memory: [batch_size, memory_word_size].
        attention_weight = tf.reshape(tf.nn.softmax(attention_weight_logits),
                                      shape=[batch_size, memory_size, 1])
        # The multiplication is elementwise and relies on broadcasting the weights
        # across memory_word_size. Then we sum across the memory slots.
        attended_memory = tf.reduce_sum(memory * attention_weight, axis=[1])

        # Infer shape of result as much as possible.
        inferred_batch_size, _, inferred_memory_word_size = (
            memory.get_shape().as_list())
        attended_memory.set_shape(
            [inferred_batch_size, inferred_memory_word_size])

        return AttentionOutput(read=attended_memory,
                               weights=tf.squeeze(attention_weight, [2]),
                               weight_logits=attention_weight_logits)
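
The mask handling above clamps ignored slots to the smallest float32 before the softmax; a self-contained NumPy sketch of that trick (values are illustrative, not from the source):

import numpy as np

logits = np.array([2.0, 1.0, 3.0], dtype=np.float32)
mask = np.array([True, False, True])

finfo = np.finfo(np.float32)
# Kept slots get an upper bound of float32 max (a no-op); masked slots
# are clamped down to float32 min, as in the module above.
lower_bound = np.where(mask, finfo.max, finfo.min)
masked_logits = np.minimum(logits, lower_bound)

# Softmax: the masked slot's weight underflows to exactly 0, so it
# contributes nothing to the weighted read.
weights = np.exp(masked_logits - masked_logits.max())
weights /= weights.sum()
print(weights)  # ~[0.269, 0.0, 0.731]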