Example #1
        def model():
            # Shared birthplace
            x_start = yield Root(
                tfd.Categorical(probs=tf.ones(n) / n, dtype=tf.int32))
            y_start = yield tfd.Categorical(probs=tf.ones(n) / n,
                                            dtype=tf.int32)

            x = m * [x_start]
            y = m * [y_start]

            for t in range(n_steps):
                for i in range(m):
                    # Construct PDF for next step in walk
                    # Start with PDF for all mass on current point.
                    ox = tf.one_hot(x[i], n)
                    oy = tf.one_hot(y[i], n)
                    o = ox[..., :, None] * oy[..., None, :]

                    # Deliberate choice of non-centered distribution as
                    # reduced symmetry lends itself to better testing.
                    p = (0.1 * tf.roll(o, shift=[0, -1], axis=[-2, -1]) +
                         0.2 * tf.roll(o, shift=[0, 1], axis=[-2, -1]) +
                         0.3 * tf.roll(o, shift=[-1, 0], axis=[-2, -1]) +
                         0.4 * tf.roll(o, shift=[1, 0], axis=[-2, -1]))

                    # Reshape just last two dimensions.
                    p = tf.reshape(p, _cat(p.shape[:-2], [-1]))
                    xy = yield tfd.Categorical(probs=p, dtype=tf.int32)
                    x[i] = xy // n
                    y[i] = xy % n

            # 2 * m noisy 2D observations at end
            for i in range(m):
                yield tfd.Normal(tf.cast(x[i], dtype=tf.float32), scale=2.0)
                yield tfd.Normal(tf.cast(y[i], dtype=tf.float32), scale=2.0)
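A minimal standalone sketch (not from the original model) of what the four tf.roll calls above do: each call shifts the one-hot mass o to a neighboring grid cell, so their weighted sum is the transition kernel of the 2D walk.

import tensorflow as tf

n = 4
o = tf.one_hot(1, n)[:, None] * tf.one_hot(2, n)[None, :]  # all mass at (1, 2)
right = tf.roll(o, shift=[0, 1], axis=[-2, -1])            # mass moves to (1, 3)
down = tf.roll(o, shift=[1, 0], axis=[-2, -1])             # mass moves to (2, 2)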
Example #2
def overlap_add_conv1d(inputs: tf.Tensor,
                       filters: tf.Tensor,
                       fft_length: int = 4096) -> tf.Tensor:
    """FFT based convolution in 1D, using the overlap-add method.

  Args:
   inputs: a tf.Tensor<float>[batch_size, seq_length, 1] of input sequences.
   filters: a tf.Tensor<float>[filter_length, 1, channels].
   fft_length: an int, the length of the Fourier transform.

  Returns:
   A tf.Tensor<float>[batch_size, seq_length, channels] containing the response
   to the 1D convolutions with the filters.
  """
    seq_len = tf.shape(inputs)[1]
    filter_len = tf.shape(filters)[0]
    overlap = filter_len - 1
    seg_size = fft_length - overlap
    f_filters = tf.expand_dims(tf.signal.rfft(tf.transpose(filters, (1, 2, 0)),
                                              fft_length=[fft_length]),
                               axis=2)
    framed = tf.signal.frame(tf.transpose(inputs, (0, 2, 1)),
                             frame_length=seg_size,
                             frame_step=seg_size,
                             pad_end=True)
    paddings = [[0, 0], [0, 0], [0, 0], [int(overlap / 2), int(overlap / 2)]]
    framed = tf.pad(framed, paddings)
    f_inputs = tf.signal.rfft(framed, fft_length=[fft_length])
    result = tf.signal.irfft(f_inputs * tf.math.conj(f_filters))
    result = tf.roll(result, filter_len // 2 - 1, axis=-1)
    result = tf.signal.overlap_and_add(result, frame_step=seg_size)
    result = tf.transpose(result, (0, 2, 1))[..., :seq_len, :]
    output = tf.roll(result, 1 - filter_len // 2, axis=1)
    shape = tf.concat([tf.shape(inputs)[:-1], tf.shape(filters)[-1:]], axis=0)
    return tf.reshape(output, shape)
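A hypothetical usage sketch for overlap_add_conv1d (shapes chosen only for illustration, assuming eager execution):

inputs = tf.random.normal([2, 16000, 1])        # [batch_size, seq_length, 1]
filters = tf.random.normal([255, 1, 3])         # [filter_length, 1, channels]
response = overlap_add_conv1d(inputs, filters)  # -> [2, 16000, 3]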
Example #3
def roll(a, shift, axis=None):  # pylint: disable=missing-docstring
  a = asarray(a).data

  if axis is not None:
    return utils.tensor_to_ndarray(tf.roll(a, shift, axis))

  # If axis is None, the roll happens as a 1-d tensor.
  original_shape = tf.shape(a)
  a = tf.roll(tf.reshape(a, [-1]), shift, 0)
  return utils.tensor_to_ndarray(tf.reshape(a, original_shape))
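Here asarray and utils.tensor_to_ndarray belong to the surrounding numpy-emulation module; a plain-TF sketch of the axis=None branch (assuming the same semantics) is:

a = tf.constant([[1, 2, 3], [4, 5, 6]])
rolled = tf.reshape(tf.roll(tf.reshape(a, [-1]), shift=2, axis=0), tf.shape(a))
# rolled == [[5, 6, 1], [2, 3, 4]] -- the roll wraps across row boundaries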
Example #4
def fft_conv1d(inputs: tf.Tensor, filters: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """FFT based convolution in 1D.

  We round the input length to the closest upper power of 2 before FFT.

  Args:
   inputs: a tf.Tensor<float>[batch_size, seq_length, 1] of input sequences.
   filters: a tf.Tensor<float>[filter_length, 1, channels]

  Returns:
   A tf.Tensor<float>[batch_size, seq_length, channels] containing the response
   to the 1D convolutions with the filters, and a tf.Tensor<float>[1,] with the
   L1 norm of the filters FFT.
  """
    seq_length = tf.shape(inputs)[1]
    fft_length = upper_power_of_2(seq_length)
    filter_length = tf.shape(filters)[0]
    f_inputs = tf.signal.rfft(tf.transpose(inputs, (0, 2, 1)),
                              fft_length=[fft_length])
    f_filters = tf.signal.rfft(tf.transpose(filters, (1, 2, 0)),
                               fft_length=[fft_length])
    f_filters_l1 = tf.reduce_sum(tf.math.abs(f_filters))
    result = tf.transpose(tf.signal.irfft(f_inputs * tf.math.conj(f_filters)),
                          (0, 2, 1))
    output = tf.roll(result, filter_length // 2 - 1, axis=1)
    output = output[:, :seq_length, :]

    shape = tf.concat([tf.shape(inputs)[:-1], tf.shape(filters)[-1:]], axis=0)
    return tf.reshape(output, shape), f_filters_l1
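upper_power_of_2 is a helper from the original module that is not shown here; one plausible definition, included purely as an assumption, is the classic bit-twiddling round-up:

def upper_power_of_2(value: tf.Tensor) -> tf.Tensor:
    """Smallest power of two >= `value` (assumed behavior of the helper)."""
    v = tf.cast(value, tf.int32) - 1
    for shift in (1, 2, 4, 8, 16):  # smear the highest set bit into lower bits
        v = tf.bitwise.bitwise_or(v, tf.bitwise.right_shift(v, shift))
    return v + 1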
Example #5
def nearby_difference(x):
    """Compute L2 norms for nearby entries in a batch."""
    # This is a very rough measure of diversity.
    with tf.device('cpu'):
        x1 = tf.reshape(x, shape=[int(x.shape[0]), -1])
        x2 = tf.roll(x1, shift=1, axis=0)
        return tf.sqrt(tf.reduce_sum(tf.square(tf.subtract(x1, x2))))
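Illustrative usage (not from the original): the roll by one along the batch axis pairs each example with its neighbor, wrapping the last example around to the first.

batch = tf.random.normal([8, 32, 32, 3])
diversity = nearby_difference(batch)  # scalar: L2 norm of all neighbor differences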
Example #6
 def distribution_fn(sample):
   num_frames = sample.shape[-1]
   mask = tf.one_hot(0, num_frames)[:, tf.newaxis]
   probs = tf.roll(tf.one_hot(sample, 3), shift=1, axis=-2)
   probs = probs * (1.0 - mask) + tf.convert_to_tensor([0.5, 0.5, 0]) * mask
   return tfd.Independent(tfd.Categorical(probs=probs),
                          reinterpreted_batch_ndims=1)
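Hypothetical usage (with tfd as above): because the one-hot probabilities are rolled by one along the frame axis, every frame after the first deterministically copies the previous frame of sample, while frame 0 is drawn uniformly from {0, 1}.

sample = tf.constant([1, 2, 0, 1])  # four frames with values in {0, 1, 2}
dist = distribution_fn(sample)
dist.sample()                       # e.g. [0, 1, 2, 0] or [1, 1, 2, 0]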
Example #7
def _sample_bates(total_count, low, high, n, seed=None):
  """Vectorized production of `Bates` samples.

  Args:
    total_count: (Batches of) counts of `Uniform`s to take means of.  Should
      have integer dtype and already be broadcasted to the batch shape.
    low: (Batches of) lower bounds of the `Uniform` variables to sample.  Should
      be the same floating dtype as `high` and broadcastable to the batch shape.
    high: (Batches of) upper bounds of the `Uniform` variables to sample. Should
      be the same floating dtype as `low` and broadcastable to the batch shape.
    n: `int32` number of samples to generate.
    seed: Random seed to pass to `Uniform` sampler.

  Returns:
    samples: Samples of (batches of) the `Bates` variable.  Will have same dtype
      as `low` and `high`. If the batch shape is `[B1,..., Bn]`, `samples` has
      shape `[n, B1,..., Bn]`.
  """

  # 1. Sample Uniform(0, 1)s, flattening the batch dimension into axis 0.
  uniform_sample_shape = tf.concat([[tf.reduce_sum(total_count)], [n]], axis=0)
  uniform_samples = samplers.uniform(
      uniform_sample_shape, minval=0., maxval=1., dtype=low.dtype, seed=seed)
  # 2. Produce segment means.
  segment_lengths = tf.reshape(total_count, [-1])
  segment_ids = tf.repeat(tf.range(tf.size(segment_lengths)), segment_lengths)
  flatmeans = tf.math.segment_mean(uniform_samples, segment_ids)
  # 3. Reshape and transpose segment means back to the original shape.
  outshape = tf.concat([tf.shape(total_count), [n]], axis=0)
  tmeans = tf.reshape(flatmeans, outshape)
  axes = tf.range(tf.rank(tmeans))
  means = tf.transpose(tmeans, tf.roll(axes, shift=1, axis=0))
  # 4. Shift/scale from (0, 1) to (low, high).
  return low + (high - low) * means
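The tf.roll on axes above builds a cyclic permutation that moves the trailing sample dimension to the front; a small sketch of just that step:

axes = tf.range(3)                     # [0, 1, 2] for a rank-3 `tmeans`
perm = tf.roll(axes, shift=1, axis=0)  # [2, 0, 1]
# tf.transpose(tmeans, perm) turns shape [B1, B2, n] into [n, B1, B2].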
Example #8
def one_hot_minus(inputs, shift):
    """Performs (inputs - shift) % vocab_size in the one-hot space.

  Args:
    inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
      Tensor.
    shift: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
      Tensor specifying how much to shift the corresponding one-hot vector in
      inputs. Soft values perform a "weighted shift": for example,
      shift=[0.2, 0.3, 0.5] performs a linear combination of 0.2 * shifting by
      zero; 0.3 * shifting by one; and 0.5 * shifting by two.

  Returns:
    Tensor of same shape and dtype as inputs.
  """
    # TODO(trandustin): Implement with circular conv1d.
    inputs = tf.convert_to_tensor(inputs)
    shift = tf.cast(shift, inputs.dtype)
    vocab_size = inputs.shape[-1]
    if isinstance(vocab_size, tf1.Dimension):
        vocab_size = vocab_size.value
    # Form a [..., vocab_size, vocab_size] matrix. Each batch element of
    # inputs will vector-matrix multiply the vocab_size x vocab_size matrix. This
    # "shifts" the inputs batch element by the corresponding shift batch element.
    shift_matrix = tf.stack(
        [tf.roll(shift, i, axis=-1) for i in range(vocab_size)], axis=-2)
    outputs = tf.einsum('...v,...uv->...u', inputs, shift_matrix)
    return outputs
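Illustrative usage of one_hot_minus with vocab_size = 3: subtracting a hard one-hot shift of 1 from a hard one-hot 2 yields the one-hot for (2 - 1) % 3 = 1.

inputs = tf.one_hot(2, 3)     # [0., 0., 1.]
shift = tf.one_hot(1, 3)      # shift every position down by one
one_hot_minus(inputs, shift)  # -> [0., 1., 0.]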
Example #9
 def my_fn(ex):
   res = ex.copy()  # copy once so the roll applied to each feature is kept
   for feat in ["inputs", "targets"]:
     tokens = ex[feat]
     n_tokens = tf.size(tokens)
     random_shift = tf.random.uniform([], maxval=n_tokens, dtype=tf.int32)
     res[feat] = tf.roll(tokens, shift=random_shift, axis=0)
   return res
Example #10
 def grad(dy):
     dshp = tf.shape(x) - tf.shape(dy)
     z = tf.zeros(tf.where(tf.equal(0, dshp), tf.shape(x), dshp),
                  dtype=x.dtype)
     dx = tf.roll(tf.concat([dy, z], axis=axis),
                  shift=begin[axis],
                  axis=axis)
     return dx
Example #11
def random_token_preprocessor(ex, seed):
  """Selects a random shift to roll the tokens by for each feature."""
  res = ex.copy()  # copy once so both features keep their rolled tokens
  for feat in ["inputs", "targets"]:
    tokens = ex[feat]
    n_tokens = tf.size(tokens)
    random_shift = tf.random.stateless_uniform(
        [], maxval=n_tokens, dtype=tf.int32, seed=seed)
    res[feat] = tf.roll(tokens, shift=random_shift, axis=0)
  return res
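A hypothetical call (the seed pair is illustrative): each feature is rolled by a random offset, so relative token order is preserved up to the wrap-around point.

ex = {"inputs": tf.range(10), "targets": tf.range(6)}
rolled = random_token_preprocessor(ex, seed=(42, 7))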
Example #12
    def prep_fn(image, label):
        """Image preprocessing function."""
        if config.roll_pixels:
            image = tf.roll(image, config.roll_pixels, -2)
        if is_training:
            image = tf.image.random_flip_left_right(image)
            image = tf.pad(image, [[4, 4], [4, 4], [0, 0]])
            image = tf.image.random_crop(image, CIFAR_SHAPE)

        image = tf.image.convert_image_dtype(image, tf.float32)
        return image, label
Example #13
def _week_day_mappers(weekend_mask):
    """Creates functions to map from ordinals to week days and inverse.

  Creates functions to map from ordinal space (i.e. days since 31 Dec 0) to
  week days. The function assigns the value of 0 to the first non weekend
  day in the week starting on Sunday, 31 Dec 1 through to Saturday, 6 Jan 1 and
  the value assigned to each successive work day is incremented by 1. For a day
  that is not a week day, this count is not incremented from the previous week
  day (hence, multiple ordinal days may have the same week day value).

  Args:
    weekend_mask: A bool `Tensor` of length 7 or None. The weekend mask.

  Returns:
    A tuple of callables.
      `forward`: Takes one `Tensor` argument containing ordinals and returns a
        tuple of two `Tensor`s of the same shape as the input. The first
        `Tensor` is of type `int32` and contains the week day value. The second
        is a bool `Tensor` indicating whether the supplied ordinal was a weekend
        day (i.e. True where the day is a weekend day and False otherwise).
      `backward`: Takes one int32 `Tensor` argument containing week day values
        and returns an int32 `Tensor` containing ordinals for those week days.
  """
    if weekend_mask is None:
        default_forward = lambda x: (x, tf.zeros_like(x, dtype=tf.bool))
        identity = lambda x: x
        return default_forward, identity
    weekend_mask = tf.convert_to_tensor(weekend_mask, dtype=tf.bool)
    weekend_mask = tf.roll(weekend_mask, -_DAYOFWEEK_0, axis=0)
    weekday_mask = tf.logical_not(weekend_mask)
    weekday_offsets = tf.cumsum(tf.cast(weekday_mask, dtype=tf.int32))
    num_workdays = weekday_offsets[-1]
    weekday_offsets -= 1  # Adjust the first workday to index 0.
    ordinal_offsets = tf.convert_to_tensor([0, 1, 2, 3, 4, 5, 6],
                                           dtype=tf.int32)
    ordinal_offsets = ordinal_offsets[weekday_mask]

    def forward(ordinals):
        """Adjusts the ordinals by removing the number of weekend days so far."""
        mod, remainder = ordinals // 7, ordinals % 7
        weekday_values = mod * num_workdays + tf.gather(
            weekday_offsets, remainder)
        is_weekday = tf.gather(weekday_mask, remainder)
        return weekday_values, is_weekday

    def backward(weekday_values):
        """Converts from weekend adjusted values to ordinals."""
        return ((weekday_values // num_workdays) * 7 +
                tf.gather(ordinal_offsets, weekday_values % num_workdays))

    return forward, backward
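A usage sketch, assuming a Monday-through-Sunday mask ordering with True marking weekend days (_DAYOFWEEK_0 is a module constant in the original source):

weekend_mask = tf.constant([False] * 5 + [True] * 2)  # Saturday/Sunday off
forward, backward = _week_day_mappers(weekend_mask)
weekday_values, is_weekday = forward(tf.constant([733000, 733001, 733002]))
ordinals = backward(weekday_values)  # maps workday counts back to ordinals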
Example #14
def _shift_right_by_one(tensor: tf.Tensor, axis: int = -1) -> tf.Tensor:
    """Shift the 1d input tensor to the right by one position without wrapping."""

    if not tensor.dtype.is_integer:
        raise ValueError("Only integer types are supported.")

    # tf.roll wraps around the axis.
    rolled = tf.roll(tensor, shift=1, axis=axis)

    # Zero out the first position by multiplying with [0, 1, 1, ..., 1].
    reverse_onehot = tf.one_hot(0,
                                depth=tensor.shape[axis],
                                on_value=0,
                                off_value=1,
                                dtype=tensor.dtype)

    return rolled * reverse_onehot
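Illustrative usage: because the wrapped-around element is masked out, the shift drops the last token and inserts a zero at the front.

tokens = tf.constant([5, 6, 7, 8])
_shift_right_by_one(tokens)  # -> [0, 5, 6, 7]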
Example #15
            def _do_update(x_update_diff_norm_sq, x_update,
                           hess_matmul_x_update):  # pylint: disable=missing-docstring
                hessian_column_with_l2 = sparse_or_dense_matvecmul(
                    hessian_unregularized_loss_outer,
                    hessian_unregularized_loss_middle *
                    _sparse_or_dense_matmul_onehot(
                        hessian_unregularized_loss_outer, coord),
                    adjoint_a=True)

                if l2_regularizer is not None:
                    hessian_column_with_l2 += _one_hot_like(
                        hessian_column_with_l2,
                        coord,
                        on_value=2. * l2_regularizer)

                # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
                # order to conform to `hess_matmul_x_update`.
                n = tf.rank(hessian_column_with_l2)
                perm = tf.roll(tf.range(n), shift=1, axis=0)
                hessian_column_with_l2 = tf.transpose(a=hessian_column_with_l2,
                                                      perm=perm)

                # Update the entire batch at `coord` even if `delta` may be 0 at some
                # batch coordinates. In those cases, adding `delta` is a no-op.
                x_update = tf.tensor_scatter_nd_add(x_update, [[coord]],
                                                    [delta])

                with tf.control_dependencies([x_update]):
                    x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
                    hess_matmul_x_update_ = (hess_matmul_x_update +
                                             delta * hessian_column_with_l2)

                    # Hint that loop vars retain the same shape.
                    x_update_diff_norm_sq_.set_shape(
                        x_update_diff_norm_sq_.shape.merge_with(
                            x_update_diff_norm_sq.shape))
                    hess_matmul_x_update_.set_shape(
                        hess_matmul_x_update_.shape.merge_with(
                            hess_matmul_x_update.shape))

                    return [
                        x_update_diff_norm_sq_, x_update, hess_matmul_x_update_
                    ]
Example #16
def crps_score(labels=None, predictive_samples=None):
    r"""Computes the Continuous Ranked Probability Score (CRPS).

  The Continuous Ranked Probability Score is a [proper scoring rule][1] for
  assessing the probabilistic predictions of a model against a realized value.
  The CRPS is

  \\(\textrm{CRPS}(F,y) = \int_{-\infty}^{\infty} (F(z) - 1_{z \geq y})^2 dz.\\)

  Here \\(F\\) is the cumulative distribution function of the model predictive
  distribution and \\(y\\) is the realized ground truth value.

  The CRPS can be used as a loss function for training an implicit model for
  probabilistic regression.  It can also be used to assess the predictive
  performance of a probabilistic regression model.

  In this implementation we use an equivalent representation of the CRPS,

  \\(\textrm{CRPS}(F,y) = E_{z \sim F}[|z-y|] - (1/2) E_{z,z' \sim F}[|z-z'|].\\)

  This equivalent representation has an unbiased sample estimate, and our
  implementation of the CRPS has complexity O(n m).

  #### References
  [1]: Tilmann Gneiting, Adrian E. Raftery.
       Strictly Proper Scoring Rules, Prediction, and Estimation.
       Journal of the American Statistical Association, Vol. 102, 2007.
       https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf

  Args:
    labels: Tensor, (n,), with tf.float32 or tf.float64 real-valued targets.
    predictive_samples: Tensor, (n,m), with tf.float32 or tf.float64 values.
      Each row at [i,:] contains m samples of the model predictive
      distribution, p(y|x_i).

  Returns:
    crps: (n,) Tensor, the CRPS score for each instance; a lower score
      indicates a better fit.
  """
    if labels is None:
        raise ValueError("target labels must be provided")
    if labels.shape.ndims != 1:
        raise ValueError("target labels must be of rank 1")

    if predictive_samples is None:
        raise ValueError("predictive samples must be provided")
    if predictive_samples.shape.ndims != 2:
        raise ValueError("predictive samples must be of rank 2")
    if predictive_samples.shape[0] != labels.shape[0]:
        raise ValueError("first dimension of predictive samples shape "
                         "must match target labels shape")

    pairwise_diff = tf.roll(predictive_samples, 1, axis=1) - predictive_samples
    predictive_diff = tf.abs(pairwise_diff)
    estimated_dist_pairwise = tf.reduce_mean(input_tensor=predictive_diff,
                                             axis=1)

    labels = tf.expand_dims(labels, 1)
    dist_realization = tf.reduce_mean(tf.abs(predictive_samples - labels),
                                      axis=1)

    crps = dist_realization - 0.5 * estimated_dist_pairwise

    return crps
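Hypothetical usage: three targets scored against 100 predictive samples each (all values illustrative).

labels = tf.constant([0.2, -1.0, 3.5])
predictive_samples = tf.random.normal([3, 100])
crps = crps_score(labels=labels, predictive_samples=predictive_samples)
# crps has shape (3,); lower values indicate a better probabilistic fit.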
Example #17
 def rollaxis(x, shift):
   return tf.transpose(x, tf.roll(tf.range(tf.rank(x)), shift=shift, axis=0))
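A quick sketch of what the permutation does: shift=1 moves the last axis to the front.

x = tf.zeros([2, 3, 4])
y = rollaxis(x, shift=1)  # shape becomes [4, 2, 3]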
Example #18
    def _parse_fn(record):
        """Parses a record into a feature_dict."""
        feature_values = tf.io.parse_single_example(
            serialized=record,
            features={
                'i/o':
                tf.io.FixedLenFeature([], tf.string, default_value=''),
                'program_encoding':
                tf.io.FixedLenFeature([], tf.string, default_value=''),
            })

        ios = tf.strings.split(tf.strings.split(feature_values['i/o'],
                                                sep='>'),
                               sep='<')

        inputs, outputs = ios.merge_dims(0, 1)[::2], ios.merge_dims(0, 1)[1::2]

        # Parse inputs into tokens.
        inputs = tf.strings.unicode_split(inputs, 'UTF-8').to_tensor()
        inputs = spec_vocab_table.lookup(inputs)  # Map characters to tokens.

        # Parse outputs into tokens.
        outputs_with_separators = (tf.strings.unicode_split(
            outputs, 'UTF-8').to_tensor())
        outputs_with_separators = spec_vocab_table.lookup(
            outputs_with_separators)
        split_outputs = tf.strings.unicode_split(
            tf.strings.split(outputs, sep='|'), 'UTF-8')
        outputs = split_outputs.merge_dims(1, 2).to_tensor()
        outputs = spec_vocab_table.lookup(outputs)

        # Compute indices for the start of each part of the spec, w.r.t. the
        # original spec.
        separator_indices = tf.where(
            tf.equal(outputs_with_separators, separator_id))[:, 1]
        separator_indices = tf.reshape(
            separator_indices, (tf.shape(outputs_with_separators)[0], -1))
        start_indices = separator_indices - tf.expand_dims(
            tf.range(tf.shape(separator_indices)[1], dtype=tf.int64), 0)
        start_indices = tf.concat((tf.zeros(
            (tf.shape(start_indices)[0], 1), dtype=tf.int64), start_indices),
                                  axis=1)

        num_examples = tf.shape(start_indices)[0]
        num_parts = tf.shape(start_indices)[1]

        # Construct the shifted spec suffixes.
        flat_start_indices = tf.reshape(start_indices, (-1, ))
        prefix_mask = (1 - tf.sequence_mask(
            flat_start_indices, maxlen=tf.shape(outputs)[-1], dtype=tf.int64))
        masked_outputs = tf.repeat(outputs, num_parts, axis=0) * prefix_mask
        output_suffixes = tf.vectorized_map(
            fn=lambda x: tf.roll(x[0], x[1], axis=0),
            elems=(masked_outputs, -flat_start_indices))

        # Compute indices for the start/end of spec parts, w.r.t. the shifted spec
        # suffixes.
        ground_truth_start_indices = tf.zeros((num_examples * num_parts, ),
                                              dtype=tf.int64)
        cumulative_end_indices = tf.concat(
            (start_indices,
             tf.math.count_nonzero(outputs, axis=-1, keepdims=True)),
            axis=1)
        ground_truth_end_indices = tf.reshape(
            cumulative_end_indices[:, 1:] - cumulative_end_indices[:, :-1],
            (-1, ))

        # Construct the actual spec parts to predict.
        range_indices = tf.expand_dims(tf.range(tf.shape(output_suffixes)[-1],
                                                dtype=tf.int64),
                                       axis=0)
        part_mask = tf.where(
            tf.logical_and(
                range_indices >= tf.expand_dims(ground_truth_start_indices,
                                                axis=1),
                range_indices < tf.expand_dims(ground_truth_end_indices,
                                               axis=1)), 1, 0)
        output_parts = output_suffixes * tf.cast(part_mask, tf.int64)
        output_parts = tf.pad(output_parts,
                              [[0, 0], [0, 1]])  # Make room for sep.
        # TODO(kshi): roll output_parts leftward by start_indices for SCAN.
        first_zero_index = tf.math.count_nonzero(output_parts, axis=-1)
        output_parts += tf.one_hot(first_zero_index,
                                   depth=tf.shape(output_parts)[-1],
                                   dtype=tf.int64) * separator_id

        # Reshape everything so that different spec suffixes become different
        # dataset elements.
        output_suffixes_reshaped = tf.transpose(
            tf.reshape(output_suffixes, (num_examples, num_parts, -1)),
            (1, 0, 2))
        output_parts_reshaped = tf.transpose(
            tf.reshape(output_parts, (num_examples, num_parts, -1)), (1, 0, 2))
        inputs_reshaped = tf.reshape(tf.tile(inputs, (num_parts, 1)),
                                     (num_parts, num_examples, -1))
        ground_truth_start_indices_reshaped = tf.transpose(
            tf.reshape(ground_truth_start_indices, (num_examples, num_parts)))
        ground_truth_end_indices_reshaped = tf.transpose(
            tf.reshape(ground_truth_end_indices, (num_examples, num_parts)))

        # Combine spec parts from all examples into one sequence with separator
        # tokens between examples and ending in EOS.
        shifts = tf.cumsum(tf.concat((tf.zeros(
            (num_parts, 1),
            dtype=tf.int64), ground_truth_end_indices_reshaped[:, :-1] + 1),
                                     1),
                           axis=-1)
        flat_shifts = tf.reshape(shifts, (-1, ))
        output_len = tf.shape(output_parts_reshaped)[-1]
        flat_spec_parts = tf.reshape(output_parts_reshaped, (-1, output_len))
        flat_spec_parts = tf.pad(flat_spec_parts,
                                 [[0, 0], [0, max_target_length - output_len]])
        combined_spec_parts = tf.vectorized_map(
            fn=lambda x: tf.roll(x[0], x[1], axis=0),
            elems=(flat_spec_parts, flat_shifts))
        combined_spec_parts = tf.reshape(combined_spec_parts,
                                         (num_parts, num_examples, -1))
        combined_spec_parts = tf.reduce_sum(combined_spec_parts, axis=1)
        first_zero_index = tf.math.count_nonzero(combined_spec_parts, axis=-1)
        combined_spec_parts += tf.one_hot(
            first_zero_index,
            depth=tf.shape(combined_spec_parts)[-1],
            dtype=tf.int64) * eos_id

        # Create a dataset containing data for all spec suffixes.
        dataset = tf.data.Dataset.from_tensor_slices({
            'inputs':
            inputs_reshaped,
            'outputs':
            output_suffixes_reshaped,
            'spec_parts':
            combined_spec_parts,
            'start_index':
            ground_truth_start_indices_reshaped,
            'end_index':
            ground_truth_end_indices_reshaped
        })
        return dataset
Example #19
 def roll_fn(image, label):
     """Function to roll pixels."""
     image = tf.roll(image, config.roll_pixels, -2)
     return image, label
Example #20
def minimize_one_step(gradient_unregularized_loss,
                      hessian_unregularized_loss_outer,
                      hessian_unregularized_loss_middle,
                      x_start,
                      tolerance,
                      l1_regularizer,
                      l2_regularizer=None,
                      maximum_full_sweeps=1,
                      learning_rate=None,
                      name=None):
  """One step of (the outer loop of) the minimization algorithm.

  This function returns a new value of `x`, equal to `x_start + x_update`.  The
  increment `x_update in R^n` is computed by a coordinate descent method, that
  is, by a loop in which each iteration updates exactly one coordinate of
  `x_update`.  (Some updates may leave the value of the coordinate unchanged.)

  The particular update method used is to apply an L1-based proximity operator,
  "soft threshold", whose fixed point `x_update_fix` is the desired minimum

  ```none
  x_update_fix = argmin{
      Loss(x_start + x_update')
        + l1_regularizer * ||x_start + x_update'||_1
        + l2_regularizer * ||x_start + x_update'||_2**2
      : x_update' }
  ```

  where in each iteration `x_update'` is constrained to have at most one nonzero
  coordinate.

  This update method preserves sparsity, i.e., tends to find sparse solutions if
  `x_start` is sparse.  Additionally, the choice of step size is based on
  curvature (Hessian), which significantly speeds up convergence.

  This algorithm assumes that `Loss` is convex, at least in a region surrounding
  the optimum.  (If `l2_regularizer > 0`, then only weak convexity is needed.)

  Args:
    gradient_unregularized_loss: (Batch of) `Tensor` with the same shape and
      dtype as `x_start` representing the gradient, evaluated at `x_start`, of
      the unregularized loss function (denoted `Loss` above).  (In all current
      use cases, `Loss` is the negative log likelihood.)
    hessian_unregularized_loss_outer: (Batch of) `Tensor` or `SparseTensor`
      having the same dtype as `x_start`, and shape `[N, n]` where `x_start` has
      shape `[n]`, satisfying the property
      `Transpose(hessian_unregularized_loss_outer)
      @ diag(hessian_unregularized_loss_middle)
      @ hessian_unregularized_loss_inner
      = (approximation of) Hessian matrix of Loss, evaluated at x_start`.
    hessian_unregularized_loss_middle: (Batch of) vector-shaped `Tensor` having
      the same dtype as `x_start`, and shape `[N]` where
      `hessian_unregularized_loss_outer` has shape `[N, n]`, satisfying the
      property
      `Transpose(hessian_unregularized_loss_outer)
      @ diag(hessian_unregularized_loss_middle)
      @ hessian_unregularized_loss_inner
      = (approximation of) Hessian matrix of Loss, evaluated at x_start`.
    x_start: (Batch of) vector-shaped, `float` `Tensor` representing the current
      value of the argument to the Loss function.
    tolerance: scalar, `float` `Tensor` representing the convergence threshold.
      The optimization step will terminate early, returning its current value of
      `x_start + x_update`, once the following condition is met:
      `||x_update_end - x_update_start||_2 / (1 + ||x_start||_2)
      < sqrt(tolerance)`,
      where `x_update_end` is the value of `x_update` at the end of a sweep and
      `x_update_start` is the value of `x_update` at the beginning of that
      sweep.
    l1_regularizer: scalar, `float` `Tensor` representing the weight of the L1
      regularization term (see equation above).  If L1 regularization is not
      required, then `tfp.glm.fit_one_step` is preferable.
    l2_regularizer: scalar, `float` `Tensor` representing the weight of the L2
      regularization term (see equation above).
      Default value: `None` (i.e., no L2 regularization).
    maximum_full_sweeps: Python integer specifying maximum number of sweeps to
      run.  A "sweep" consists of an iteration of coordinate descent on each
      coordinate. After this many sweeps, the algorithm will terminate even if
      convergence has not been reached.
      Default value: `1`.
    learning_rate: scalar, `float` `Tensor` representing a multiplicative factor
      used to dampen the proximal gradient descent steps.
      Default value: `None` (i.e., factor is conceptually `1`).
    name: Python string representing the name of the TensorFlow operation.
      The default name is `"minimize_one_step"`.

  Returns:
    x: (Batch of) `Tensor` having the same shape and dtype as `x_start`,
      representing the updated value of `x`, that is, `x_start + x_update`.
    is_converged: scalar, `bool` `Tensor` indicating whether convergence
      occurred across all batches within the specified number of sweeps.
    iter: scalar, `int` `Tensor` representing the actual number of coordinate
      updates made (before achieving convergence).  Since each sweep consists of
      `tf.size(x_start)` iterations, the maximum number of updates is
      `maximum_full_sweeps * tf.size(x_start)`.

  #### References

  [1]: Jerome Friedman, Trevor Hastie and Rob Tibshirani. Regularization Paths
       for Generalized Linear Models via Coordinate Descent. _Journal of
       Statistical Software_, 33(1), 2010.
       https://www.jstatsoft.org/article/view/v033i01/v33i01.pdf

  [2]: Guo-Xun Yuan, Chia-Hua Ho and Chih-Jen Lin. An Improved GLMNET for
       L1-regularized Logistic Regression. _Journal of Machine Learning
       Research_, 13, 2012.
       http://www.jmlr.org/papers/volume13/yuan12a/yuan12a.pdf
  """
  with tf.name_scope(name or 'minimize_one_step'):
    x_shape = _get_shape(x_start)
    batch_shape = x_shape[:-1]
    dims = x_shape[-1]

    def _hessian_diag_elt_with_l2(coord):  # pylint: disable=missing-docstring
      # Returns the (coord, coord) entry of
      #
      #   Hessian(UnregularizedLoss(x) + l2_regularizer * ||x||_2**2)
      #
      # evaluated at x = x_start.
      inner_square = tf.reduce_sum(
          _sparse_or_dense_matmul_onehot(
              hessian_unregularized_loss_outer, coord)**2,
          axis=-1)
      unregularized_component = inner_square * tf.gather(
          hessian_unregularized_loss_middle, coord, axis=-1)
      l2_component = _mul_or_none(2., l2_regularizer)
      return _add_ignoring_nones(unregularized_component, l2_component)

    grad_loss_with_l2 = _add_ignoring_nones(
        gradient_unregularized_loss, _mul_or_none(2., l2_regularizer, x_start))

    # We define `x_update_diff_norm_sq_convergence_threshold` such that the
    # convergence condition
    #     ||x_update_end - x_update_start||_2 / (1 + ||x_start||_2)
    #     < sqrt(tolerance)
    # is equivalent to
    #     ||x_update_end - x_update_start||_2**2
    #     < x_update_diff_norm_sq_convergence_threshold.
    x_update_diff_norm_sq_convergence_threshold = (
        tolerance * (1. + tf.norm(tensor=x_start, ord=2, axis=-1))**2)

    # Reshape update vectors so that the coordinate sweeps happen along the
    # first dimension. This is so that we can use tensor_scatter_update to make
    # sparse updates along the first axis without copying the Tensor.
    # TODO(b/118789120): Switch to something like tf.tensor_scatter_nd_add if
    # or when it exists.
    update_shape = tf.concat([[dims], batch_shape], axis=-1)

    def _loop_cond(iter_, x_update_diff_norm_sq, x_update,
                   hess_matmul_x_update):
      del x_update
      del hess_matmul_x_update
      sweep_complete = (iter_ > 0) & tf.equal(iter_ % dims, 0)
      small_delta = (
          x_update_diff_norm_sq < x_update_diff_norm_sq_convergence_threshold)
      converged = sweep_complete & small_delta
      allowed_more_iterations = iter_ < maximum_full_sweeps * dims
      return allowed_more_iterations & tf.reduce_any(~converged)

    def _loop_body(  # pylint: disable=missing-docstring
        iter_, x_update_diff_norm_sq, x_update, hess_matmul_x_update):
      # Inner loop of the minimizer.
      #
      # This loop updates a single coordinate of x_update.  Ideally, an
      # iteration of this loop would set
      #
      #   x_update[j] += argmin{ LocalLoss(x_update + z*e_j) : z in R }
      #
      # where
      #
      #   LocalLoss(x_update')
      #     = LocalLossSmoothComponent(x_update')
      #         + l1_regularizer * (||x_start + x_update'||_1 -
      #                             ||x_start + x_update||_1)
      #    := (UnregularizedLoss(x_start + x_update') -
      #        UnregularizedLoss(x_start + x_update)
      #         + l2_regularizer * (||x_start + x_update'||_2**2 -
      #                             ||x_start + x_update||_2**2)
      #         + l1_regularizer * (||x_start + x_update'||_1 -
      #                             ||x_start + x_update||_1)
      #
      # In this algorithm we approximate the above argmin using (univariate)
      # proximal gradient descent:
      #
      # (*)  x_update[j] = prox_{t * l1_regularizer * L1}(
      #                 x_update[j] -
      #                 t * d/dz|z=0 UnivariateLocalLossSmoothComponent(z))
      #
      # where
      #
      #   UnivariateLocalLossSmoothComponent(z)
      #       := LocalLossSmoothComponent(x_update + z*e_j)
      #
      # and we approximate
      #
      #       d/dz UnivariateLocalLossSmoothComponent(z)
      #     = grad LocalLossSmoothComponent(x_update))[j]
      #    ~= (grad LossSmoothComponent(x_start)
      #         + x_update matmul HessianOfLossSmoothComponent(x_start))[j].
      #
      # To choose the parameter t, we squint and pretend that the inner term of
      # (*) is a Newton update as if we were using Newton's method to minimize
      # UnivariateLocalLossSmoothComponent.  That is, we choose t such that
      #
      #   -t * d/dz ULLSC = -learning_rate * (d/dz ULLSC) / (d^2/dz^2 ULLSC)
      #
      # at z=0.  Hence
      #
      #   t = learning_rate / (d^2/dz^2|z=0 ULLSC)
      #     = learning_rate / HessianOfLossSmoothComponent(
      #                           x_start + x_update)[j,j]
      #    ~= learning_rate / HessianOfLossSmoothComponent(
      #                           x_start)[j,j]
      #
      # The above approximation is equivalent to assuming that
      # HessianOfUnregularizedLoss is constant, i.e., ignoring third-order
      # effects.
      #
      # Note that because LossSmoothComponent is (assumed to be) convex, t is
      # positive.

      # In above notation, coord = j.
      coord = iter_ % dims
      # x_update_diff_norm_sq := ||x_update_end - x_update_start||_2**2,
      # computed incrementally, where x_update_end and x_update_start are as
      # defined in the convergence criteria.  Accordingly, we reset
      # x_update_diff_norm_sq to zero at the beginning of each sweep.
      x_update_diff_norm_sq = tf.where(
          tf.equal(coord, 0),
          dtype_util.as_numpy_dtype(x_update_diff_norm_sq.dtype)(0.),
          x_update_diff_norm_sq)

      # Recall that x_update and hess_matmul_x_update have the rightmost
      # dimension transposed to the leftmost dimension.
      w_old = (tf.gather(x_start, coord, axis=-1) +
               tf.gather(x_update, coord, axis=0))
      # This is the coordinatewise Newton update if no L1 regularization.
      # In above notation, newton_step = -t * (approximation of d/dz|z=0 ULLSC).
      second_deriv = _hessian_diag_elt_with_l2(coord)
      newton_step = -_mul_ignoring_nones(  # pylint: disable=invalid-unary-operand-type
          learning_rate,
          (tf.gather(grad_loss_with_l2, coord, axis=-1) +
           tf.gather(hess_matmul_x_update, coord, axis=0))) / second_deriv

      # Applying the soft-threshold operator accounts for L1 regularization.
      # In above notation, delta =
      #     prox_{t*l1_regularizer*L1}(w_old + newton_step) - w_old.
      delta = (
          soft_threshold(
              w_old + newton_step,
              _mul_ignoring_nones(learning_rate, l1_regularizer) / second_deriv)
          - w_old)

      def _do_update(x_update_diff_norm_sq, x_update, hess_matmul_x_update):  # pylint: disable=missing-docstring
        hessian_column_with_l2 = sparse_or_dense_matvecmul(
            hessian_unregularized_loss_outer,
            hessian_unregularized_loss_middle * _sparse_or_dense_matmul_onehot(
                hessian_unregularized_loss_outer, coord),
            adjoint_a=True)

        if l2_regularizer is not None:
          hessian_column_with_l2 += _one_hot_like(
              hessian_column_with_l2, coord, on_value=2. * l2_regularizer)

        # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
        # order to conform to `hess_matmul_x_update`.
        n = tf.rank(hessian_column_with_l2)
        perm = tf.roll(tf.range(n), shift=1, axis=0)
        hessian_column_with_l2 = tf.transpose(
            a=hessian_column_with_l2, perm=perm)

        # Update the entire batch at `coord` even if `delta` may be 0 at some
        # batch coordinates. In those cases, adding `delta` is a no-op.
        x_update = tf.tensor_scatter_nd_add(x_update, [[coord]], [delta])

        with tf.control_dependencies([x_update]):
          x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
          hess_matmul_x_update_ = (
              hess_matmul_x_update + delta * hessian_column_with_l2)

          # Hint that loop vars retain the same shape.
          x_update_diff_norm_sq_.set_shape(
              x_update_diff_norm_sq_.shape.merge_with(
                  x_update_diff_norm_sq.shape))
          hess_matmul_x_update_.set_shape(
              hess_matmul_x_update_.shape.merge_with(
                  hess_matmul_x_update.shape))

          return [x_update_diff_norm_sq_, x_update, hess_matmul_x_update_]

      inputs_to_update = [x_update_diff_norm_sq, x_update, hess_matmul_x_update]
      return [iter_ + 1] + prefer_static.cond(
          # Note on why checking delta (a difference of floats) for equality to
          # zero is ok:
          #
          # First of all, x - x == 0 in floating point -- see
          # https://stackoverflow.com/a/2686671
          #
          # Delta will conceptually equal zero when one of the following holds:
          # (i)   |w_old + newton_step| <= threshold and w_old == 0
          # (ii)  |w_old + newton_step| > threshold and
          #       w_old + newton_step - sign(w_old + newton_step) * threshold
          #          == w_old
          #
          # In case (i) comparing delta to zero is fine.
          #
          # In case (ii), newton_step conceptually equals
          #     sign(w_old + newton_step) * threshold.
          # Also remember
          #     threshold = -newton_step / (approximation of d/dz|z=0 ULLSC).
          # So (ii) happens when
          #     (approximation of d/dz|z=0 ULLSC) == -sign(w_old + newton_step).
          # If we did not require LossSmoothComponent to be strictly convex,
          # then this could actually happen a non-negligible amount of the time,
          # e.g. if the loss function is piecewise linear and one of the pieces
          # has slope 1.  But since LossSmoothComponent is strictly convex, (ii)
          # should not systematically happen.
          tf.reduce_all(tf.equal(delta, 0.)),
          lambda: inputs_to_update,
          lambda: _do_update(*inputs_to_update))

    base_dtype = x_start.dtype.base_dtype
    iter_, x_update_diff_norm_sq, x_update, _ = tf.while_loop(
        cond=_loop_cond,
        body=_loop_body,
        loop_vars=[
            tf.zeros([], dtype=np.int32, name='iter'),
            tf.zeros(
                batch_shape, dtype=base_dtype, name='x_update_diff_norm_sq'),
            tf.zeros(update_shape, dtype=base_dtype, name='x_update'),
            tf.zeros(
                update_shape, dtype=base_dtype, name='hess_matmul_x_update'),
        ])

    # Convert back x_update to the shape of x_start by transposing the leftmost
    # dimension to the rightmost.
    n = tf.rank(x_update)
    perm = tf.roll(tf.range(n), shift=-1, axis=0)
    x_update = tf.transpose(a=x_update, perm=perm)

    converged = tf.reduce_all(x_update_diff_norm_sq <
                              x_update_diff_norm_sq_convergence_threshold)
    return x_start + x_update, converged, iter_ / dims