def get_effective_mask(self):
    if self.round_mask:
        # during train, clamp a random 50% to their rounded values, and backprop the other 50% directly
        # during test, clamp all of them to their rounded values
        which_to_clamp = tf.cond(
            learning_phase(), lambda: gen_math_ops.round(
                tf.random.uniform(self.kernel_mask.shape, minval=0, maxval=1)),
            lambda: tf.ones(self.kernel_mask.shape))
        binary_mask = gen_math_ops.round(tf.nn.sigmoid(self.kernel_mask))
    else:
        # during train, clamp all of them to 0's and 1's sampled from a Bernoulli and backprop through the probabilities
        # during test, sample them from the Bernoulli as well (rather than rounding)
        which_to_clamp = tf.ones(self.kernel_mask.shape)
        binary_mask = tf.cond(
            learning_phase(),
            lambda: tf.cast(tf.distributions.Bernoulli(probs=tf.nn.sigmoid(
                self.kernel_mask)).sample(),
                            dtype=tf.float32) + tf.nn.sigmoid(self.kernel_mask)
            - tf.stop_gradient(tf.nn.sigmoid(self.kernel_mask)),
            lambda: tf.cast(tf.distributions.Bernoulli(probs=tf.nn.sigmoid(
                self.kernel_mask)).sample(),
                            dtype=tf.float32))

    return which_to_clamp * binary_mask + (1 - which_to_clamp) * tf.nn.sigmoid(
        self.kernel_mask)
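The `binary_mask` expression in the training branch above is a straight-through estimator: the forward value is the hard 0/1 sample, while the gradient is taken through the sigmoid probabilities. Below is a minimal sketch of that trick, assuming TF 2.x eager execution; `logits` is an illustrative stand-in for `self.kernel_mask`.

import tensorflow as tf

# Minimal sketch of the straight-through trick (TF 2.x eager assumed).
logits = tf.Variable([[-1.0, 0.5], [2.0, -0.3]])

with tf.GradientTape() as tape:
    probs = tf.nn.sigmoid(logits)
    # Bernoulli draw with probability `probs`, as a float 0/1 tensor.
    sample = tf.cast(tf.random.uniform(tf.shape(probs)) < probs, tf.float32)
    # Forward value equals `sample`; gradient equals d(probs)/d(logits).
    mask = sample + probs - tf.stop_gradient(probs)
    loss = tf.reduce_sum(mask)

print(tape.gradient(loss, logits))  # non-zero despite the hard 0/1 sample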
Example #2
  def _update_statistics_from_mini_batch(
      self, statistics, auxiliary_variables, times, values):
    """Given mini-batch input, update `statistics` and `auxiliary_variables`."""
   values = math_ops.cast(values, self._dtype)
   # The density (measured in times per observation) that we see in each part
   # of the mini-batch.
   batch_inter_observation_duration = (math_ops.cast(
       math_ops.reduce_max(times, axis=1) - math_ops.reduce_min(times, axis=1),
       self._dtype) / math_ops.cast(
           array_ops.shape(times)[1] - 1, self._dtype))
   # Co-locate updates with their variables to minimize race conditions when
   # updating statistics.
   with ops.colocate_with(auxiliary_variables.max_time_seen):
     # There is a race condition if this value is being updated from multiple
     # workers. However, it should eventually reach the correct value if the
     # last chunk is presented enough times.
     max_time_seen_assign = state_ops.assign(
         auxiliary_variables.max_time_seen,
         gen_math_ops.maximum(auxiliary_variables.max_time_seen,
                              math_ops.reduce_max(times)))
   with ops.colocate_with(auxiliary_variables.chunk_count):
     chunk_count_assign = state_ops.assign_add(auxiliary_variables.chunk_count,
                                               array_ops.shape(
                                                   times,
                                                   out_type=dtypes.int64)[0])
   with ops.colocate_with(auxiliary_variables.inter_observation_duration_sum):
     inter_observation_duration_assign = state_ops.assign_add(
         auxiliary_variables.inter_observation_duration_sum,
         math_ops.reduce_sum(batch_inter_observation_duration))
   with ops.colocate_with(auxiliary_variables.example_count):
     example_count_assign = state_ops.assign_add(
         auxiliary_variables.example_count,
         array_ops.size(times, out_type=dtypes.int64))
   # Note: These mean/variance updates assume that all points are equally
   # likely, which is not true if _chunks_ are sampled uniformly from the space
   # of all possible contiguous chunks, since points at the start and end of
   # the series are then members of fewer chunks. For series which are much
   # longer than the chunk size (the usual/expected case), this effect becomes
   # irrelevant.
   with ops.colocate_with(auxiliary_variables.overall_feature_sum):
     overall_feature_sum_assign = state_ops.assign_add(
         auxiliary_variables.overall_feature_sum,
         math_ops.reduce_sum(values, axis=[0, 1]))
   with ops.colocate_with(auxiliary_variables.overall_feature_sum_of_squares):
     overall_feature_sum_of_squares_assign = state_ops.assign_add(
         auxiliary_variables.overall_feature_sum_of_squares,
         math_ops.reduce_sum(values**2, axis=[0, 1]))
   per_chunk_aux_updates = control_flow_ops.group(
       max_time_seen_assign, chunk_count_assign,
       inter_observation_duration_assign, example_count_assign,
       overall_feature_sum_assign, overall_feature_sum_of_squares_assign)
   with ops.control_dependencies([per_chunk_aux_updates]):
     example_count_float = math_ops.cast(auxiliary_variables.example_count,
                                         self._dtype)
     new_feature_mean = (auxiliary_variables.overall_feature_sum /
                         example_count_float)
     overall_feature_mean_update = state_ops.assign(
         statistics.overall_feature_moments.mean, new_feature_mean)
     overall_feature_var_update = state_ops.assign(
         statistics.overall_feature_moments.variance,
         # De-biased n / (n - 1) variance correction
         example_count_float / (example_count_float - 1.) *
         (auxiliary_variables.overall_feature_sum_of_squares /
          example_count_float - new_feature_mean**2))
     # TODO(b/35675805): Remove this cast
     min_time_batch = math_ops.cast(math_ops.argmin(times[:, 0]), dtypes.int32)
     def series_start_updates():
       # If this is the lowest-time chunk that we have seen so far, update
       # series start moments to reflect that. Note that these statistics are
       # "best effort", as there are race conditions in the update (however,
       # they should eventually converge if the start of the series is
       # presented enough times).
       mean, variance = nn.moments(
           values[min_time_batch, :self._starting_variance_window_size],
           axes=[0])
       return control_flow_ops.group(
           state_ops.assign(statistics.series_start_moments.mean, mean),
           state_ops.assign(statistics.series_start_moments.variance,
                            variance))
     with ops.colocate_with(statistics.start_time):
       series_start_update = control_flow_ops.cond(
           # Update moments whenever we even match the lowest time seen so far,
           # to ensure that series start statistics are eventually updated to
           # their correct values, despite race conditions (i.e. eventually
           # statistics.start_time will reflect the global lowest time, and
           # given that we will eventually update the series start moments to
           # their correct values).
           math_ops.less_equal(times[min_time_batch, 0],
                               statistics.start_time),
           series_start_updates,
           control_flow_ops.no_op)
       with ops.control_dependencies([series_start_update]):
         # There is a race condition if this update is performed in parallel on
         # multiple workers. Since models may be sensitive to being presented
         # with times before the putative start time, the value of this
         # variable is post-processed above to guarantee that each worker is
         # presented with a start time which is at least as low as the lowest
         # time in its current mini-batch.
         start_time_update = state_ops.assign(statistics.start_time,
                                              gen_math_ops.minimum(
                                                  statistics.start_time,
                                                  math_ops.reduce_min(times)))
     inter_observation_duration_estimate = (
         auxiliary_variables.inter_observation_duration_sum / math_ops.cast(
             auxiliary_variables.chunk_count, self._dtype))
     # Estimate the total number of observations as:
     #   (end time - start time + 1) * average intra-chunk time density
     total_observation_count_update = state_ops.assign(
         statistics.total_observation_count,
         math_ops.cast(
             gen_math_ops.round(
                 math_ops.cast(auxiliary_variables.max_time_seen -
                               statistics.start_time + 1, self._dtype) /
                 inter_observation_duration_estimate), dtypes.int64))
     per_chunk_stat_updates = control_flow_ops.group(
         overall_feature_mean_update, overall_feature_var_update,
         series_start_update, start_time_update,
         total_observation_count_update)
   return per_chunk_stat_updates
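The `overall_feature_var_update` above recovers the unbiased sample variance from nothing but a running sum, a running sum of squares, and a count. A small NumPy check of that identity, with made-up data and a simplified (observations, features) layout:

import numpy as np

# Streaming statistics: keep only sum, sum of squares, and count.
values = np.random.RandomState(0).normal(size=(1000, 3))
n = float(values.shape[0])                  # example_count
feature_sum = values.sum(axis=0)            # overall_feature_sum
feature_sum_sq = (values ** 2).sum(axis=0)  # overall_feature_sum_of_squares

mean = feature_sum / n
# De-biased n / (n - 1) variance correction, as in the assign above.
variance = n / (n - 1.0) * (feature_sum_sq / n - mean ** 2)

assert np.allclose(variance, values.var(axis=0, ddof=1))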
Example #3
    def _update_statistics_from_mini_batch(self, statistics,
                                           auxiliary_variables, times, values):
        """Given mini-batch input, update `statistics` and `auxiliary_variables`."""
        values = math_ops.cast(values, self._dtype)
        # The density (measured in times per observation) that we see in each part
        # of the mini-batch.
        batch_inter_observation_duration = (
            math_ops.cast(
                math_ops.reduce_max(times, axis=1) -
                math_ops.reduce_min(times, axis=1), self._dtype) /
            math_ops.cast(array_ops.shape(times)[1] - 1, self._dtype))
        # Co-locate updates with their variables to minimize race conditions when
        # updating statistics.
        with ops.colocate_with(auxiliary_variables.max_time_seen):
            # There is a race condition if this value is being updated from multiple
            # workers. However, it should eventually reach the correct value if the
            # last chunk is presented enough times.
            max_time_seen_assign = state_ops.assign(
                auxiliary_variables.max_time_seen,
                gen_math_ops.maximum(auxiliary_variables.max_time_seen,
                                     math_ops.reduce_max(times)))
        with ops.colocate_with(auxiliary_variables.chunk_count):
            chunk_count_assign = state_ops.assign_add(
                auxiliary_variables.chunk_count,
                array_ops.shape(times, out_type=dtypes.int64)[0])
        with ops.colocate_with(
                auxiliary_variables.inter_observation_duration_sum):
            inter_observation_duration_assign = state_ops.assign_add(
                auxiliary_variables.inter_observation_duration_sum,
                math_ops.reduce_sum(batch_inter_observation_duration))
        with ops.colocate_with(auxiliary_variables.example_count):
            example_count_assign = state_ops.assign_add(
                auxiliary_variables.example_count,
                array_ops.size(times, out_type=dtypes.int64))
        # Note: These mean/variance updates assume that all points are equally
        # likely, which is not true if _chunks_ are sampled uniformly from the space
        # of all possible contiguous chunks, since points at the start and end of
        # the series are then members of fewer chunks. For series which are much
        # longer than the chunk size (the usual/expected case), this effect becomes
        # irrelevant.
        with ops.colocate_with(auxiliary_variables.overall_feature_sum):
            overall_feature_sum_assign = state_ops.assign_add(
                auxiliary_variables.overall_feature_sum,
                math_ops.reduce_sum(values, axis=[0, 1]))
        with ops.colocate_with(
                auxiliary_variables.overall_feature_sum_of_squares):
            overall_feature_sum_of_squares_assign = state_ops.assign_add(
                auxiliary_variables.overall_feature_sum_of_squares,
                math_ops.reduce_sum(values**2, axis=[0, 1]))
        per_chunk_aux_updates = control_flow_ops.group(
            max_time_seen_assign, chunk_count_assign,
            inter_observation_duration_assign, example_count_assign,
            overall_feature_sum_assign, overall_feature_sum_of_squares_assign)
        with ops.control_dependencies([per_chunk_aux_updates]):
            example_count_float = math_ops.cast(
                auxiliary_variables.example_count, self._dtype)
            new_feature_mean = (auxiliary_variables.overall_feature_sum /
                                example_count_float)
            overall_feature_mean_update = state_ops.assign(
                statistics.overall_feature_moments.mean, new_feature_mean)
            overall_feature_var_update = state_ops.assign(
                statistics.overall_feature_moments.variance,
                # De-biased n / (n - 1) variance correction
                example_count_float / (example_count_float - 1.) *
                (auxiliary_variables.overall_feature_sum_of_squares /
                 example_count_float - new_feature_mean**2))
            # TODO(b/35675805): Remove this cast
            min_time_batch = math_ops.cast(math_ops.argmin(times[:, 0]),
                                           dtypes.int32)

            def series_start_updates():
                # If this is the lowest-time chunk that we have seen so far, update
                # series start moments to reflect that. Note that these statistics are
                # "best effort", as there are race conditions in the update (however,
                # they should eventually converge if the start of the series is
                # presented enough times).
                mean, variance = nn.moments(values[
                    min_time_batch, :self._starting_variance_window_size],
                                            axes=[0])
                return control_flow_ops.group(
                    state_ops.assign(statistics.series_start_moments.mean,
                                     mean),
                    state_ops.assign(statistics.series_start_moments.variance,
                                     variance))

            with ops.colocate_with(statistics.start_time):
                series_start_update = control_flow_ops.cond(
                    # Update moments whenever we even match the lowest time seen so far,
                    # to ensure that series start statistics are eventually updated to
                    # their correct values, despite race conditions (i.e. eventually
                    # statistics.start_time will reflect the global lowest time, and
                    # given that we will eventually update the series start moments to
                    # their correct values).
                    math_ops.less_equal(times[min_time_batch, 0],
                                        statistics.start_time),
                    series_start_updates,
                    control_flow_ops.no_op)
                with ops.control_dependencies([series_start_update]):
                    # There is a race condition if this update is performed in parallel on
                    # multiple workers. Since models may be sensitive to being presented
                    # with times before the putative start time, the value of this
                    # variable is post-processed above to guarantee that each worker is
                    # presented with a start time which is at least as low as the lowest
                    # time in its current mini-batch.
                    start_time_update = state_ops.assign(
                        statistics.start_time,
                        gen_math_ops.minimum(statistics.start_time,
                                             math_ops.reduce_min(times)))
            inter_observation_duration_estimate = (
                auxiliary_variables.inter_observation_duration_sum /
                math_ops.cast(auxiliary_variables.chunk_count, self._dtype))
            # Estimate the total number of observations as:
            #   (end time - start time + 1) * average intra-chunk time density
            total_observation_count_update = state_ops.assign(
                statistics.total_observation_count,
                math_ops.cast(
                    gen_math_ops.round(
                        math_ops.cast(
                            auxiliary_variables.max_time_seen -
                            statistics.start_time + 1, self._dtype) /
                        inter_observation_duration_estimate), dtypes.int64))
            per_chunk_stat_updates = control_flow_ops.group(
                overall_feature_mean_update, overall_feature_var_update,
                series_start_update, start_time_update,
                total_observation_count_update)
        return per_chunk_stat_updates
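The `total_observation_count_update` at the end estimates the length of the full series from the overall time span and the average per-chunk spacing. A NumPy sketch of that arithmetic with made-up timestamps:

import numpy as np

# Two chunks of timestamps with a spacing of 3 time units per observation.
times = np.array([[0, 3, 6, 9],
                  [12, 15, 18, 21]])

# Per-chunk "time units per observation", then averaged over chunks.
chunk_durations = (times.max(axis=1) - times.min(axis=1)) / (times.shape[1] - 1)
inter_observation_duration = chunk_durations.mean()

# (end time - start time + 1) / average spacing, rounded to an integer count.
max_time_seen, start_time = times.max(), times.min()
total_observations = np.round(
    (max_time_seen - start_time + 1) / inter_observation_duration)
print(total_observations)  # 7.0, i.e. round((21 - 0 + 1) / 3)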
Example #4
    def transition_to_powers(self, powers):
        """Computes TransitionMatrix^power efficiently.

    For an n x n transition matrix we have:

      (TransitionMatrix**power)_{i, j} = (-1) ** i * sin(pi * power) / (n + 1)
          * ((-1) ** j / sin(pi / (n + 1) * (power - i + j))
             + 1 / sin(pi / (n + 1) * (power - i - 1)))

    The sin(pi * power) term is zero whenever "power" is an integer. However,
    the 1 / sin(x) terms (cosecants) occasionally (when their arguments are
    multiples of pi) cancel out this value. The limit as the argument approaches
    an integer value gives the "correct" result, but computing these separately
    gives 0 * inf = NaN. Instead, there is a special case for near-integer
    values.

    Args:
      powers: A floating point Tensor of powers to raise the transition matrix
        to.
    Returns:
      A [..., self._num_latent_values - 1, self._num_latent_values - 1] floating
        point Tensor with the transition matrix raised to each power in
        `powers`.

    """
        num_latent_values_float = math_ops.cast(self._num_latent_values,
                                                self.dtype)
        latent_values_per_period = (
            num_latent_values_float /
            math_ops.cast(self._true_periodicity, dtype=self.dtype))
        original_matrix_powers = (math_ops.cast(powers, self.dtype) *
                                  latent_values_per_period)
        global_coeff = (math_ops.sin(original_matrix_powers * numpy.pi) /
                        num_latent_values_float)[..., None, None]
        matrix_dimension_range = array_ops.reshape(
            math_ops.range(self._num_latent_values - 1),
            array_ops.concat([
                array_ops.ones([array_ops.rank(original_matrix_powers)],
                               dtype=dtypes.int32),
                [self._num_latent_values - 1]
            ],
                             axis=0))
        matrix_dimension_range_float = math_ops.cast(matrix_dimension_range,
                                                     self.dtype)
        alternating = math_ops.cast(1 - 2 * (matrix_dimension_range % 2),
                                    self.dtype)
        row_addend = 1. / math_ops.sin(numpy.pi / num_latent_values_float *
                                       (original_matrix_powers[..., None] -
                                        matrix_dimension_range_float - 1))
        column_minus_row = (matrix_dimension_range_float[..., None, :] -
                            matrix_dimension_range_float[..., None])
        full_matrix_addend = (alternating[..., None, :] / math_ops.sin(
            numpy.pi / num_latent_values_float *
            (original_matrix_powers[..., None, None] + column_minus_row)))
        continuous_construction = global_coeff * alternating[..., None] * (
            row_addend[..., None] + full_matrix_addend)
        # For integer powers, the above formula is only correct in the limit,
        # yielding NaNs as written. We defer to the super-class in such cases, which
        # computes integer powers exactly.
        return array_ops.where(
            self._close_to_integer(original_matrix_powers),
            super(ResolutionCycleModel, self).transition_to_powers(
                math_ops.cast(gen_math_ops.round(original_matrix_powers),
                              dtypes.int64)), continuous_construction)
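The `array_ops.where` fallback exists because the closed form is an indeterminate 0 * inf exactly at integer powers, as the docstring notes. A tiny NumPy illustration of the failure mode (the values are made up):

import numpy as np

# At an integer power the sin(pi * power) prefactor vanishes while a cosecant
# term diverges; evaluating the factors separately gives an indeterminate
# product instead of the finite limit.
print(np.float64(0.0) * np.inf)   # nan

# Away from integer powers both factors stay finite and the formula is usable.
print(np.sin(np.pi * 2.5))        # non-zero prefactor, no special case needed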
Example #5
  def _close_to_integer(self, value):
    value = math_ops.cast(value, self.dtype)
    return math_ops.less(math_ops.abs(value - gen_math_ops.round(value)),
                         self._near_integer_threshold)
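For reference, the same near-integer test written with plain NumPy; the threshold here is illustrative, not the value the model actually uses:

import numpy as np

def close_to_integer(value, threshold=1e-6):
    # True where `value` is within `threshold` of the nearest integer.
    value = np.asarray(value, dtype=np.float64)
    return np.abs(value - np.round(value)) < threshold

print(close_to_integer([3.0, 3.0000001, 3.2]))  # [ True  True False]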
Example #6
  def transition_to_powers(self, powers):
    """Computes TransitionMatrix^power efficiently.

    For an n x n transition matrix we have:

      (TransitionMatrix**power)_{i, j} = (-1) ** i * sin(pi * power) / (n + 1)
          * ((-1) ** j / sin(pi / (n + 1) * (power - i + j))
             + 1 / sin(pi / (n + 1) * (power - i - 1)))

    The sin(pi * power) term is zero whenever "power" is an integer. However,
    the 1 / sin(x) terms (cosecants) occasionally (when their arguments are
    multiples of pi) cancel out this value. The limit as the argument approaches
    an integer value gives the "correct" result, but computing these separately
    gives 0 * inf = NaN. Instead, there is a special case for near-integer
    values.

    Args:
      powers: A floating point Tensor of powers to raise the transition matrix
        to.
    Returns:
      A [..., self._num_latent_values - 1, self._num_latent_values - 1] floating
        point Tensor with the transition matrix raised to each power in
        `powers`.

    """
    num_latent_values_float = math_ops.cast(self._num_latent_values, self.dtype)
    latent_values_per_period = (num_latent_values_float / math_ops.cast(
        self._true_periodicity, dtype=self.dtype))
    original_matrix_powers = (math_ops.cast(powers, self.dtype) *
                              latent_values_per_period)
    global_coeff = (math_ops.sin(original_matrix_powers * numpy.pi) /
                    num_latent_values_float)[..., None, None]
    matrix_dimension_range = array_ops.reshape(
        math_ops.range(self._num_latent_values - 1),
        array_ops.concat(
            [
                array_ops.ones(
                    [array_ops.rank(original_matrix_powers)],
                    dtype=dtypes.int32), [self._num_latent_values - 1]
            ],
            axis=0))
    matrix_dimension_range_float = math_ops.cast(matrix_dimension_range,
                                                 self.dtype)
    alternating = math_ops.cast(1 - 2 * (matrix_dimension_range % 2),
                                self.dtype)
    row_addend = 1. / math_ops.sin(numpy.pi / num_latent_values_float * (
        original_matrix_powers[..., None] - matrix_dimension_range_float - 1))
    column_minus_row = (matrix_dimension_range_float[..., None, :]
                        - matrix_dimension_range_float[..., None])
    full_matrix_addend = (alternating[..., None, :] / math_ops.sin(
        numpy.pi / num_latent_values_float *
        (original_matrix_powers[..., None, None] + column_minus_row)))
    continuous_construction = global_coeff * alternating[..., None] * (
        row_addend[..., None] + full_matrix_addend)
    # For integer powers, the above formula is only correct in the limit,
    # yielding NaNs as written. We defer to the super-class in such cases, which
    # computes integer powers exactly.
    return array_ops.where(
        self._close_to_integer(original_matrix_powers),
        super(ResolutionCycleModel, self).transition_to_powers(
            math_ops.cast(
                gen_math_ops.round(original_matrix_powers), dtypes.int64)),
        continuous_construction)
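The `array_ops.reshape`/`array_ops.concat` dance above just gives `range(n - 1)` one leading size-1 axis per batch dimension of `powers`, so it broadcasts against `powers[..., None]`. A NumPy sketch with illustrative shapes:

import numpy as np

num_latent_values = 5
powers = np.array([[0.5, 1.0, 2.5],
                   [3.0, 4.25, 7.0]])   # arbitrary batch shape (2, 3)

# range(n - 1) reshaped to (1, ..., 1, n - 1): one leading 1 per batch dim.
dim_range = np.arange(num_latent_values - 1).reshape(
    (1,) * powers.ndim + (num_latent_values - 1,))

print(dim_range.shape)                        # (1, 1, 4)
print((powers[..., None] - dim_range).shape)  # (2, 3, 4) after broadcasting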
Example #7
  def _close_to_integer(self, value):
    value = math_ops.cast(value, self.dtype)
    return math_ops.less(
        math_ops.abs(value - gen_math_ops.round(value)),
        self._near_integer_threshold)