Example #1
0
def summary_writer_function(name, tensor, function, family=None):
  """Wraps a summary-writing callable in the standard recording logic.

  Args:
    name: name of the summary
    tensor: main tensor the summary is built from
    function: callable taking a tag and a scope; it writes the summary
    family: optional, the summary's family

  Returns:
    A no-op when the context has no summary writer resource; otherwise the
    result of the conditional that performs (or skips) the write.
  """
  enclosing_scope = ops.get_name_scope()
  # A trailing slash lets the name scope be re-entered later.
  reentry_scope = enclosing_scope + "/" if enclosing_scope else enclosing_scope

  def _emit_summary():
    with ops.name_scope(reentry_scope):
      with summary_op_util.summary_scope(
          name, family, values=[tensor]) as (tag, scope):
        with ops.control_dependencies([function(tag, scope)]):
          return constant_op.constant(True)

  writer_resource = context.context().summary_writer_resource
  if writer_resource is None:
    return control_flow_ops.no_op()

  with ops.device("cpu:0"):
    summary_op = smart_cond.smart_cond(
        should_record_summaries(), _emit_summary, _nothing, name="")
    if not context.executing_eagerly():
      ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, summary_op)  # pylint: disable=protected-access
  return summary_op
Example #2
0
def categorical_crossentropy(y_true,
                             y_pred,
                             from_logits=False,
                             label_smoothing=0):
  """Returns the categorical crossentropy between targets and predictions.

  Args:
    y_true: tensor of true targets.
    y_pred: tensor of predicted targets.
    from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
      `y_pred` is assumed to encode a probability distribution.
    label_smoothing: Float in [0, 1]. If > `0` then the labels are smoothed.

  Returns:
    Categorical crossentropy loss value.
  """
  predictions = ops.convert_to_tensor(y_pred)
  targets = math_ops.cast(y_true, predictions.dtype)
  smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _smoothed_targets():
    # The class count is read from dimension 1 of the targets' dynamic shape.
    class_count = math_ops.cast(
        array_ops.shape(targets)[1], predictions.dtype)
    scaled = targets * (1.0 - smoothing)
    uniform_share = smoothing / class_count
    return scaled + uniform_share

  def _raw_targets():
    return targets

  effective_targets = smart_cond.smart_cond(
      smoothing, _smoothed_targets, _raw_targets)
  return K.categorical_crossentropy(
      effective_targets, predictions, from_logits=from_logits)
Example #3
0
def write(tag, tensor, step=None, metadata=None, name=None):
  """Emits a generic summary through the default SummaryWriter, if any.

  This is the building block underneath type-specific summary ops such as
  scalar() and image(); call it directly only when defining a new
  type-specific summary op.

  Args:
    tag: string tag identifying the summary (e.g. in TensorBoard), usually
      generated with `tf.summary.summary_scope`
    tensor: the Tensor holding the summary data to write
    step: Explicit `int64`-castable monotonic step value for this summary.
      When omitted, `tf.summary.experimental.get_step()` is used instead and
      must not be None.
    metadata: Optional SummaryMetadata, as a proto or serialized bytes
    name: Optional string name for this op.

  Returns:
    True on success, or false if no summary was written because no default
    summary writer was available.

  Raises:
    ValueError: if a default writer exists, but no step was provided and
      `tf.summary.experimental.get_step()` is None.
  """
  with ops.name_scope(name, "write_summary") as scope:
    if context.context().summary_writer is None:
      return constant_op.constant(False)

    # Fall back to the globally configured step when none was passed in.
    if step is None:
      step = get_step()
    if step is None:
      raise ValueError("No step set via 'step' argument or "
                       "tf.summary.experimental.set_step()")

    # Objects exposing SerializeToString() are serialized; any other non-None
    # value is forwarded untouched.
    serialized_metadata = metadata
    if metadata is None:
      serialized_metadata = b""
    elif hasattr(metadata, "SerializeToString"):
      serialized_metadata = metadata.SerializeToString()

    def _write_and_confirm():
      """Issues the write op and returns a True constant that depends on it."""
      # The identity op is what moves the tensor to the CPU.
      with ops.device("cpu:0"):
        summary_write = gen_summary_ops.write_summary(
            context.context().summary_writer._resource,  # pylint: disable=protected-access
            step,
            array_ops.identity(tensor),
            tag,
            serialized_metadata,
            name=scope)
        with ops.control_dependencies([summary_write]):
          return constant_op.constant(True)

    with ops.device("cpu:0"):
      result_op = smart_cond.smart_cond(
          _should_record_summaries_v2(),
          _write_and_confirm,
          _nothing,
          name="summary_cond")
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, result_op)  # pylint: disable=protected-access
      return result_op
Example #4
0
  def _contraction():
    """Tries a contraction of the worst vertex towards the face centroid.

    The contracted point replaces the worst vertex when its objective value
    does not exceed the worst objective value; otherwise the simplex is shrunk
    towards the best vertex.
    """
    offset_from_worst = face_centroid - simplex[worst_index]
    contracted_point = face_centroid - contraction * offset_from_worst
    contracted_value = objective_function(contracted_point)
    contraction_ok = contracted_value <= worst_objective_value

    def _keep_contracted():
      updated_simplex = _replace_at_index(
          simplex, worst_index, contracted_point)
      updated_values = _replace_at_index(
          objective_values, worst_index, contracted_value)
      # NOTE(review): the leading False and trailing 1 are presumably a
      # converged flag and an evaluation count -- confirm against the caller.
      return (False, updated_simplex, updated_values, 1)

    def _shrink_instead():
      return _shrink_towards_best(objective_function, simplex, best_index,
                                  shrinkage, batch_evaluate_objective)

    return smart_cond.smart_cond(
        contraction_ok, _keep_contracted, _shrink_instead)
 def testUnknown(self):
   """A placeholder-driven predicate picks the branch from the fed value."""
   with ops.Graph().as_default(), session.Session():
     fed_input = array_ops.placeholder(dtype=dtypes.int32)
     is_positive = fed_input > 0
     chosen = smart_cond.smart_cond(
         is_positive,
         lambda: constant_op.constant(1),
         lambda: constant_op.constant(2))
     for fed_value, expected in ((1, 1), (-1, 2)):
       self.assertEqual(
           chosen.eval(feed_dict={fed_input: fed_value}), expected)
 def testPlaceholderWithDefault(self):
   """The default value drives the branch unless a feed overrides it."""
   with ops.Graph().as_default(), session.Session():
     defaulted_input = array_ops.placeholder_with_default(1, shape=())
     is_positive = defaulted_input > 0
     chosen = smart_cond.smart_cond(
         is_positive,
         lambda: constant_op.constant(1),
         lambda: constant_op.constant(2))
     # No feed: the default of 1 is used.
     self.assertEqual(chosen.eval(), 1)
     # Feeding a negative value overrides the default.
     self.assertEqual(chosen.eval(feed_dict={defaulted_input: -1}), 2)
 def testSmartCondTrue(self):
   """A literal True predicate yields the true branch's value."""
   with ops.Graph().as_default(), session.Session():
     two = constant_op.constant(2)
     five = constant_op.constant(5)

     def _true_branch():
       return math_ops.multiply(two, 16)

     def _false_branch():
       return math_ops.multiply(five, 5)

     product = smart_cond.smart_cond(True, _true_branch, _false_branch)
     self.assertEqual(product.eval(), 32)
 def testSmartCondFalse(self):
   """A literal False predicate yields the false branch's value."""
   with ops.Graph().as_default(), session.Session():
     four = constant_op.constant(4)
     three = constant_op.constant(3)

     def _true_branch():
       return math_ops.multiply(four, 16)

     def _false_branch():
       return math_ops.multiply(three, 3)

     product = smart_cond.smart_cond(False, _true_branch, _false_branch)
     self.assertEqual(product.eval(), 9)
Example #9
0
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
  """Computes the binary crossentropy, averaged over the last axis.

  Args:
    y_true: tensor of true targets.
    y_pred: tensor of predicted targets.
    from_logits: forwarded unchanged to `K.binary_crossentropy`.
    label_smoothing: Float in [0, 1]. When non-zero, every label is replaced
      by `label * (1 - label_smoothing) + 0.5 * label_smoothing`.

  Returns:
    The mean over axis -1 of the element-wise binary crossentropy.
  """
  # Imported locally so this function does not depend on which module-level
  # imports the surrounding file happens to provide.
  from tensorflow.python.framework import ops  # pylint: disable=g-import-not-at-top
  from tensorflow.python.ops import math_ops  # pylint: disable=g-import-not-at-top

  # Normalise the inputs first. `smart_cond` only accepts a Tensor, a bool or
  # 0/1 as predicate, so a plain Python float such as 0.1 would otherwise raise
  # TypeError. Casting `y_true` keeps the smoothing arithmetic from mixing
  # integer labels with a float factor.
  y_pred = ops.convert_to_tensor(y_pred)
  y_true = math_ops.cast(y_true, y_pred.dtype)
  label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _smooth_labels():
    return y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing

  y_true = smart_cond.smart_cond(label_smoothing,
                                 _smooth_labels, lambda: y_true)
  return K.mean(
      K.binary_crossentropy(y_true, y_pred, from_logits=from_logits), axis=-1)
Example #10
0
 def testEval(self):
   """A constant-foldable predicate must not build the false branch."""
   with ops.Graph().as_default(), session.Session():
     one = constant_op.constant(1)
     two = constant_op.constant(2)
     # The product of two constants is known while the graph is being built,
     # so `raise_exception` (the false branch) should never be invoked.
     product_is_positive = one * two > 0
     chosen = smart_cond.smart_cond(
         product_is_positive,
         lambda: constant_op.constant(1),
         raise_exception)
     self.assertEqual(chosen.eval(feed_dict={one: 1}), 1)
Example #11
0
def binary_crossentropy(y_true, y_pred, from_logits=False, label_smoothing=0):
  """Computes the binary crossentropy, averaged over the last axis.

  Args:
    y_true: tensor of true targets; cast to the dtype of `y_pred`.
    y_pred: tensor of predicted targets.
    from_logits: forwarded unchanged to `K.binary_crossentropy`.
    label_smoothing: smoothing factor, converted to a `K.floatx()` tensor and
      used as the predicate that decides whether labels are smoothed.

  Returns:
    The mean over axis -1 of the element-wise binary crossentropy.
  """
  predictions = ops.convert_to_tensor(y_pred)
  targets = math_ops.cast(y_true, predictions.dtype)
  smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

  def _smoothed_targets():
    # Pull every label towards 0.5 by the smoothing amount.
    scaled = targets * (1.0 - smoothing)
    shift = 0.5 * smoothing
    return scaled + shift

  def _raw_targets():
    return targets

  effective_targets = smart_cond.smart_cond(
      smoothing, _smoothed_targets, _raw_targets)
  elementwise_loss = K.binary_crossentropy(
      effective_targets, predictions, from_logits=from_logits)
  return K.mean(elementwise_loss, axis=-1)
Example #12
0
  def testEval(self):
    """A constant-foldable predicate must not build the false branch."""
    # Evaluating constant expressions requires the C API to be enabled.
    if not ops._USE_C_API:
      return

    with ops.Graph().as_default(), session.Session():
      one = constant_op.constant(1)
      two = constant_op.constant(2)
      # The product of two constants is known while the graph is being built,
      # so `raise_exception` (the false branch) should never be invoked.
      product_is_positive = one * two > 0
      chosen = smart_cond.smart_cond(
          product_is_positive,
          lambda: constant_op.constant(1),
          raise_exception)
      self.assertEqual(chosen.eval(feed_dict={one: 1}), 1)
Example #13
0
def _maybe_convert_labels(y_true):
  """Maps 0/1 labels to -1/1; labels with any other value pass through.

  Args:
    y_true: tensor of labels.

  Returns:
    `2 * y_true - 1` when every element of `y_true` equals 0 or 1, otherwise
    `y_true` itself.
  """
  zero_mask = math_ops.equal(y_true, 0)
  one_mask = math_ops.equal(y_true, 1)
  all_zero_or_one = math_ops.reduce_all(
      math_ops.logical_or(zero_mask, one_mask))

  def _to_signed():
    # 0 -> -1 and 1 -> 1.
    return 2. * y_true - 1.

  def _unchanged():
    return y_true

  return smart_cond.smart_cond(all_zero_or_one, _to_signed, _unchanged)
Example #14
0
  def result(self, write_summary=True):
    """Computes `numer / denom` and optionally records it as a scalar summary.

    Args:
      write_summary: bool (or boolean tensor) indicating whether the result is
        fed to the summary before returning.
    Returns:
      The quotient `self.numer / self.denom`.

    NOTE(review): `write_summary` is not validated explicitly here; any error
    for an unsuitable value would come from `ops.convert_to_tensor` -- confirm.
    """
    # The conditional below wants a tensor predicate; wrap anything else.
    should_write = write_summary
    if not isinstance(should_write, ops.Tensor):
      should_write = ops.convert_to_tensor(should_write)

    quotient = self.numer / self.denom

    def _summarize_then_passthrough():
      summary_ops.scalar(name=self.name, tensor=quotient)
      return quotient

    def _passthrough():
      return quotient

    # The conditional's own output is discarded; the quotient is returned.
    smart_cond.smart_cond(
        should_write, _summarize_then_passthrough, _passthrough, name="")
    return quotient
Example #15
0
def write_raw_pb(tensor, step=None, name=None):
  """Writes raw `tf.compat.v1.Summary` protocol buffers as a summary.

  Experimental: lets V1-style manual summary writing (building a
  `tf.compat.v1.Summary` protocol buffer by hand) be used together with the V2
  summary writing API.

  Args:
    tensor: the string Tensor holding one or more serialized `Summary` protobufs
    step: Explicit `int64`-castable monotonic step value for this summary.
      When omitted, `tf.summary.experimental.get_step()` is used instead and
      must not be None.
    name: Optional string name for this op.

  Returns:
    True on success, or false if no summary was written because no default
    summary writer was available.

  Raises:
    ValueError: if a default writer exists, but no step was provided and
      `tf.summary.experimental.get_step()` is None.
  """
  with ops.name_scope(name, "write_raw_pb") as scope:
    if context.context().summary_writer is None:
      return constant_op.constant(False)

    # Fall back to the globally configured step when none was passed in.
    if step is None:
      step = get_step()
    if step is None:
      raise ValueError("No step set via 'step' argument or "
                       "tf.summary.experimental.set_step()")

    def _write_and_confirm():
      """Issues the write op and returns a True constant that depends on it."""
      # The identity op is what moves the tensor to the CPU.
      with ops.device("cpu:0"):
        proto_write = gen_summary_ops.write_raw_proto_summary(
            context.context().summary_writer._resource,  # pylint: disable=protected-access
            step,
            array_ops.identity(tensor),
            name=scope)
        with ops.control_dependencies([proto_write]):
          return constant_op.constant(True)

    with ops.device("cpu:0"):
      result_op = smart_cond.smart_cond(
          _should_record_summaries_v2(),
          _write_and_confirm,
          _nothing,
          name="summary_cond")
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, result_op)  # pylint: disable=protected-access
      return result_op
Example #16
0
 def _expand_and_maybe_replace():
   """Tries an expansion and keeps it only if it beats the reflected point.

   Whichever of the expanded or reflected point is selected replaces the worst
   vertex of the simplex, together with its objective value.
   """
   offset_from_centroid = reflected - face_centroid
   expanded_point = face_centroid + expansion * offset_from_centroid
   expanded_value = objective_function(expanded_point)
   expansion_wins = expanded_value < objective_at_reflected

   def _take_expanded():
     return (expanded_point, expanded_value)

   def _take_reflected():
     return (reflected, objective_at_reflected)

   chosen_point, chosen_value = smart_cond.smart_cond(
       expansion_wins, _take_expanded, _take_reflected)
   updated_simplex = _replace_at_index(simplex, worst_index, chosen_point)
   updated_values = _replace_at_index(
       objective_values, worst_index, chosen_value)
   # NOTE(review): the leading False and trailing 1 are presumably a converged
   # flag and an evaluation count -- confirm against the caller.
   return False, updated_simplex, updated_values, 1
Example #17
0
 def call(self, x, training=None):
   """Invokes the wrapped function on `x`, handling the `training` argument.

   Args:
     x: input forwarded as the first positional argument of `self._func`.
     training: optional training indicator; when None and the wrapped function
       wants one, the Keras learning phase is used.

   Returns:
     The wrapped function's result, with its static shape set when
     `self._output_shape` exists.
   """
   # The wrapped function with `x` and the stored arguments already bound.
   bound_fn = functools.partial(self._func, x, **self._arguments)

   if self._func_wants_training:
     # The function needs a Python boolean for `training`, so build both
     # variants and let smart_cond choose between them.
     if training is None:
       training = tf.keras.backend.learning_phase()  # Could be a tensor.

     def _run_training():
       return bound_fn(training=True)

     def _run_inference():
       return bound_fn(training=False)

     result = smart_cond.smart_cond(training, _run_training, _run_inference)
   else:
     result = bound_fn()

   # TODO(b/124219898): Polymorphic function should return shaped tensor.
   if hasattr(self, '_output_shape'):
     leading_dim = (x.shape[0],)
     result.set_shape(leading_dim + self._output_shape)
   return result
Example #18
0
def write(tag, tensor, step, metadata=None, name=None):
  """Emits a generic summary through the default SummaryWriter, if any.

  This is the building block underneath type-specific summary ops such as
  scalar() and image(); call it directly only when defining a new
  type-specific summary op.

  Args:
    tag: string tag identifying the summary (e.g. in TensorBoard), usually
      generated with `tf.summary.summary_scope`
    tensor: the Tensor holding the summary data to write
    step: `int64`-castable monotonic step value for this summary
    metadata: Optional SummaryMetadata, as a proto or serialized bytes
    name: Optional string name for this op.

  Returns:
    True on success, or false if no summary was written because no default
    summary writer was available.
  """
  with ops.name_scope(name, "write_summary") as scope:
    if context.context().summary_writer_resource is None:
      return constant_op.constant(False)

    # Missing metadata becomes an empty-bytes constant; objects exposing
    # SerializeToString() are serialized into a constant; anything else is
    # forwarded untouched.
    serialized_metadata = metadata
    if metadata is None:
      serialized_metadata = constant_op.constant(b"")
    elif hasattr(metadata, "SerializeToString"):
      serialized_metadata = constant_op.constant(metadata.SerializeToString())

    def _write_and_confirm():
      """Issues the write op and returns a True constant that depends on it."""
      # The identity op is what moves the tensor to the CPU.
      with ops.device("cpu:0"):
        summary_write = gen_summary_ops.write_summary(
            context.context().summary_writer_resource,
            step,
            array_ops.identity(tensor),
            tag,
            serialized_metadata,
            name=scope)
        with ops.control_dependencies([summary_write]):
          return constant_op.constant(True)

    should_record = _should_record_summaries_v2()
    return smart_cond.smart_cond(
        should_record, _write_and_confirm, _nothing, name="summary_cond")
  def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name):
    """Updates the loss scale and conditionally applies the gradients.

    Args:
      distribution: strategy object whose `extended.call_for_each_replica` runs
        the per-replica gradient application.
      grads_and_vars: iterable of (gradient, variable) pairs.
      name: name forwarded to `self._apply_gradients`.

    Returns:
      An op grouping the conditional apply op with the loss-scale update op.
    """
    gradients_only = [grad for grad, _unused_var in grads_and_vars]
    scale_update_op, grads_are_usable = self._loss_scale.update(gradients_only)

    def _apply_on_each_replica():
      # DistributionStrategy must not unwrap the MirroredVariables inside
      # grads_and_vars: even in a replica context the wrapped optimizer expects
      # mirrored variables. Wrapping them in an _UnwrapPreventer stops that
      # unwrapping from happening.
      shielded_grads_and_vars = _UnwrapPreventer(grads_and_vars)
      return distribution.extended.call_for_each_replica(
          self._apply_gradients, args=(shielded_grads_and_vars, name))

    # Note: this cond() has to be built in a cross-replica context.
    # DistributionStrategy does not support a cond in a replica context when
    # one of its branches calls `merge_call`, and
    # self._optimizer.apply_gradients calls `merge_call`.
    conditional_apply_op = smart_cond.smart_cond(
        grads_are_usable, _apply_on_each_replica, control_flow_ops.no_op)
    return control_flow_ops.group(conditional_apply_op, scale_update_op)
Example #20
0
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Returns `true_fn()` when `pred` is true and `false_fn()` otherwise.

  A `pred` that is a `Variable` is always routed through `tf.cond`. Any other
  `pred` is delegated to the framework-level `smart_cond`, which calls one
  branch directly when `pred` is a bool or has a constant value and falls back
  to `tf.cond` otherwise.

  Arguments:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable.
  """
  if isinstance(pred, variables.Variable):
    cond_impl = control_flow_ops.cond
  else:
    cond_impl = smart_module.smart_cond
  return cond_impl(pred, true_fn=true_fn, false_fn=false_fn, name=name)
Example #21
0
  def call(self, inputs, training=None):
    """Applies batch normalization using optionally quantized parameters.

    Gamma, beta and the moving mean/variance are each passed through their
    `*_quantizer_internal` callable when the matching quantizer attribute is
    set. The (possibly quantized) values are then fed to
    `nn.batch_normalization`.

    Args:
      inputs: tensor to normalize; its static shape is read via `inputs.shape`.
      training: training indicator, resolved through
        `self._get_training_value`. After resolution it may be statically
        True, statically False, or not known while the graph is built.

    Returns:
      The normalized tensor, with its static shape set to that of `inputs`.
    """
    # Gamma is quantized only when scaling is enabled and a gamma quantizer is
    # configured; otherwise the stored value is used as-is.
    if self.scale and self.gamma_quantizer:
      quantized_gamma = self.gamma_quantizer_internal(self.gamma)
    else:
      quantized_gamma = self.gamma

    # Same rule for beta, gated on centering.
    if self.center and self.beta_quantizer:
      quantized_beta = self.beta_quantizer_internal(self.beta)
    else:
      quantized_beta = self.beta

    if self.mean_quantizer:
      quantized_moving_mean = self.mean_quantizer_internal(self.moving_mean)
    else:
      quantized_moving_mean = self.moving_mean

    if self.variance_quantizer:
      quantized_moving_variance = self.variance_quantizer_internal(
          self.moving_variance)
    else:
      quantized_moving_variance = self.moving_variance

    training = self._get_training_value(training)

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.shape
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]

    # Broadcasting only necessary for single-axis batch norm where the axis is
    # not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value
    def _broadcast(v):
      # Reshape only when `v` exists, has a different rank than the input, and
      # the reduction is not over every axis except the last one.
      if (v is not None and len(v.shape) != ndims and
          reduction_axes != list(range(ndims - 1))):
        return array_ops.reshape(v, broadcast_shape)
      return v

    scale, offset = _broadcast(quantized_gamma), _broadcast(quantized_beta)

    # Determine a boolean value for `training`: could be True, False, or None.
    training_value = tf_utils.smart_constant_value(training)
    # The explicit `== False` matters: a None value (not statically known) must
    # fall through to the else branch, which `not training_value` would not do.
    if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
      quantized_mean, quantized_variance = (quantized_moving_mean,
                                            quantized_moving_variance)
    else:
      # Some of the computations here are not necessary when training==False
      # but not a constant. However, this makes the code simpler.
      keep_dims = len(self.axis) > 1
      mean, variance = self._moments(
          math_ops.cast(inputs, self._param_dtype),
          reduction_axes,
          keep_dims=keep_dims)

      moving_mean = self.moving_mean
      moving_variance = self.moving_variance

      # Batch statistics when training, stored moving statistics otherwise.
      mean = tf_utils.smart_cond(
          training, lambda: mean, lambda: ops.convert_to_tensor(moving_mean))
      variance = tf_utils.smart_cond(
          training,
          lambda: variance,
          lambda: ops.convert_to_tensor(moving_variance))

      # Captured before quantization: these are the values that feed the
      # moving-average updates defined below.
      new_mean, new_variance = mean, variance

      if self.mean_quantizer:
        quantized_mean = self.mean_quantizer_internal(mean)
      else:
        quantized_mean = mean

      if self.variance_quantizer:
        quantized_variance = self.variance_quantizer_internal(variance)
      else:
        quantized_variance = variance

      # The input size is forwarded to the moving-average update only when the
      # layer reports support for zero-size inputs.
      if self._support_zero_size_input():
        inputs_size = array_ops.size(inputs)
      else:
        inputs_size = None

      def _do_update(var, value):
        """Compute the updates for mean and variance."""
        return self._assign_moving_average(var, value, self.momentum,
                                           inputs_size)

      def mean_update():
        """Update the moving mean when training, else return it unchanged."""
        true_branch = lambda: _do_update(self.moving_mean, new_mean)
        false_branch = lambda: self.moving_mean
        return tf_utils.smart_cond(training, true_branch, false_branch)

      def variance_update():
        """Update the moving variance."""
        true_branch = lambda: _do_update(self.moving_variance, new_variance)
        false_branch = lambda: self.moving_variance
        return tf_utils.smart_cond(training, true_branch, false_branch)

      # The update callables are registered with the layer, not invoked here.
      self.add_update(mean_update)
      self.add_update(variance_update)

    # Bring all statistics and parameters to the dtype of the inputs.
    quantized_mean = math_ops.cast(quantized_mean, inputs.dtype)
    quantized_variance = math_ops.cast(quantized_variance, inputs.dtype)
    if offset is not None:
      offset = math_ops.cast(offset, inputs.dtype)
    if scale is not None:
      scale = math_ops.cast(scale, inputs.dtype)
    # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
    # math in float16 hurts validation accuracy of popular models like resnet.
    outputs = nn.batch_normalization(inputs,
                                     _broadcast(quantized_mean),
                                     _broadcast(quantized_variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    return outputs
Example #22
0
 def mean_update():
   """Update the moving mean when training, else return it unchanged."""
   def _apply_update():
     return _do_update(self.moving_mean, new_mean)

   def _keep_current():
     return self.moving_mean

   return tf_utils.smart_cond(training, _apply_update, _keep_current)
Example #23
0
 def variance_update():
   """Update the moving variance when training, else return it unchanged."""
   def _apply_update():
     return _do_update(self.moving_variance, new_variance)

   def _keep_current():
     return self.moving_variance

   return tf_utils.smart_cond(training, _apply_update, _keep_current)
Example #24
0
def hager_zhang(value_and_gradients_function,
                initial_step_size=None,
                objective_at_zero=None,
                grad_objective_at_zero=None,
                objective_at_initial_step_size=None,
                grad_objective_at_initial_step_size=None,
                threshold_use_approximate_wolfe_condition=1e-6,
                shrinkage_param=0.66,
                expansion_param=5.0,
                sufficient_decrease_param=0.1,
                curvature_param=0.9,
                name=None):
  """The Hager Zhang line search algorithm.

  Performs an inexact line search based on the algorithm of
  [Hager and Zhang (2006)][2].
  The univariate objective function `value_and_gradients_function` is typically
  generated by projecting
  a multivariate objective function along a search direction. Suppose the
  multivariate function to be minimized is `g(x1,x2, .. xn)`. Let
  (d1, d2, ..., dn) be the direction along which we wish to perform a line
  search. Then the projected univariate function to be used for line search is

  ```None
    f(a) = g(x1 + d1 * a, x2 + d2 * a, ..., xn + dn * a)
  ```

  The directional derivative along (d1, d2, ..., dn) is needed for this
  procedure. This also corresponds to the derivative of the projected function
  `f(a)` with respect to `a`. Note that this derivative must be negative for
  `a = 0` if the direction is a descent direction.

  The usual stopping criteria for the line search is the satisfaction of the
  (weak) Wolfe conditions. For details of the Wolfe conditions, see
  ref. [3]. On a finite precision machine, the exact Wolfe conditions can
  be difficult to satisfy when one is very close to the minimum and as argued
  by [Hager and Zhang (2005)][1], one can only expect the minimum to be
  determined within square root of machine precision. To improve the situation,
  they propose to replace the Wolfe conditions with an approximate version
  depending on the derivative of the function which is applied only when one
  is very close to the minimum. The following algorithm implements this
  enhanced scheme.

  ### Usage:

  Primary use of line search methods is as an internal component of a class of
  optimization algorithms (called line search based methods as opposed to
  trust region methods). Hence, the end user will typically not want to access
  line search directly. In particular, inexact line search should not be
  confused with a univariate minimization method. The stopping criteria of line
  search is the satisfaction of Wolfe conditions and not the discovery of the
  minimum of the function.

  With this caveat in mind, the following example illustrates the standalone
  usage of the line search.

  ```python
    # Define a quadratic target with minimum at 1.3.
    value_and_gradients_function = lambda x: ((x - 1.3) ** 2, 2 * (x-1.3))
    # Set initial step size.
    step_size = tf.constant(0.1)
    ls_result = tfp.optimizer.linesearch.hager_zhang(
        value_and_gradients_function, initial_step_size=step_size)
    # Evaluate the results.
    with tf.Session() as session:
      results = session.run(ls_result)
      # Ensure convergence.
      assert(results.converged)
      # If the line search converged, the left and the right ends of the
      # bracketing interval are identical.
      assert(results.left_pt == result.right_pt)
      # Print the number of evaluations and the final step size.
      print ("Final Step Size: %f, Evaluation: %d" % (results.left_pt,
                                                      results.func_evals))
  ```

  ### References:
  [1]: William Hager, Hongchao Zhang. A new conjugate gradient method with
    guaranteed descent and an efficient line search. SIAM J. Optim., Vol 16. 1,
    pp. 170-172. 2005.
    https://www.math.lsu.edu/~hozhang/papers/cg_descent.pdf
  [2]: William Hager, Hongchao Zhang. Algorithm 851: CG_DESCENT, a conjugate
    gradient method with guaranteed descent. ACM Transactions on Mathematical
    Software, Vol 32., 1, pp. 113-137. 2006.
    http://users.clas.ufl.edu/hager/papers/CG/cg_compare.pdf
  [3]: Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series in
    Operations Research. pp 33-36. 2006

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a tuple of scalar tensors of real dtype containing
      the value of the function and its derivative at that point.
      In usual optimization application, this function would be generated by
      projecting the multivariate objective function along some specific
      direction. The direction is determined by some other procedure but should
      be a descent direction (i.e. the derivative of the projected univariate
      function must be negative at 0.).
    initial_step_size: (Optional) Scalar positive `Tensor` of real dtype. The
      initial value to try to bracket the minimum. Default is `1.` as a float32.
      Note that this point need not necessarily bracket the minimum for the line
      search to work correctly but the supplied value must be greater than
      0. A good initial value will make the search converge faster.
    objective_at_zero: (Optional) Scalar `Tensor` of real dtype. If supplied,
      the value of the function at `0.`. If not supplied, it will be computed.
    grad_objective_at_zero: (Optional) Scalar `Tensor` of real dtype. If
      supplied, the derivative of the  function at `0.`. If not supplied, it
      will be computed.
    objective_at_initial_step_size: (Optional) Scalar `Tensor` of real dtype.
      If supplied, the value of the function at `initial_step_size`.
      If not supplied, it will be computed.
    grad_objective_at_initial_step_size: (Optional) Scalar `Tensor` of real
      dtype. If supplied, the derivative of the  function at
      `initial_step_size`. If not supplied, it will be computed.
    threshold_use_approximate_wolfe_condition: Scalar positive `Tensor`
      of real dtype. Corresponds to the parameter 'epsilon' in
      [Hager and Zhang (2006)][2]. Used to estimate the
      threshold at which the line search switches to approximate Wolfe
      conditions.
    shrinkage_param: Scalar positive Tensor of real dtype. Must be less than
      `1.`. Corresponds to the parameter `gamma` in
      [Hager and Zhang (2006)][2].
      If the secant**2 step does not shrink the bracketing interval by this
      proportion, a bisection step is performed to reduce the interval width.
    expansion_param: Scalar positive `Tensor` of real dtype. Must be greater
      than `1.`. Used to expand the initial interval in case it does not bracket
      a minimum. Corresponds to `rho` in [Hager and Zhang (2006)][2].
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to `delta` in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'hager_zhang' is used.

  Returns:
    results: A namedtuple containing the following attributes.
      converged: Boolean scalar `Tensor`. Whether a point satisfying
        Wolfe/Approx wolfe was found.
      func_evals: Scalar int32 `Tensor`. Number of function evaluations made.
      left_pt: Scalar `Tensor` of same dtype as `initial_step_size`. The
        left end point of the final bracketing interval. If converged is True,
        it is equal to `right_pt`. Otherwise, it corresponds to the last
        interval computed.
      objective_at_left_pt: Scalar `Tensor` of same dtype as
        `objective_at_initial_step_size`. The function value at the left
        end point. If converged is True, it is equal to `objective_at_right_pt`.
        Otherwise, it corresponds to the last interval computed.
      grad_objective_at_left_pt: Scalar `Tensor` of same dtype as
        `grad_objective_at_initial_step_size`. The derivative of the function
        at the left end point. If converged is True,
        it is equal to `grad_objective_at_right_pt`. Otherwise it
        corresponds to the last interval computed.
      right_pt: Scalar `Tensor` of same dtype as `initial_step_size`.
        The right end point of the final bracketing interval.
        If converged is True, it is equal to 'step'. Otherwise,
        it corresponds to the last interval computed.
      objective_at_right_pt: Scalar `Tensor` of same dtype as
        `objective_at_initial_step_size`.
        The function value at the right end point. If converged is True, it
        is equal to fn_step. Otherwise, it corresponds to the last
        interval computed.
      grad_objective_at_right_pt'  Scalar `Tensor` of same dtype as
        `grad_objective_at_initial_step_size`.
        The derivative of the function at the right end point.
        If converged is True, it is equal to the dfn_step.
        Otherwise it corresponds to the last interval computed.
  """
  with tf.name_scope(name, 'hager_zhang',
                     [initial_step_size,
                      objective_at_zero,
                      grad_objective_at_zero,
                      objective_at_initial_step_size,
                      grad_objective_at_initial_step_size,
                      threshold_use_approximate_wolfe_condition,
                      shrinkage_param,
                      expansion_param,
                      sufficient_decrease_param,
                      curvature_param]):

    val_0, val_c, f_lim, prepare_evals = _prepare_args(
        value_and_gradients_function,
        initial_step_size,
        objective_at_initial_step_size,
        grad_objective_at_initial_step_size,
        objective_at_zero,
        grad_objective_at_zero,
        threshold_use_approximate_wolfe_condition)

    # Check if the initial step size already satisfies the Wolfe conditions.
    # If it does, there is no further work.
    already_converged = _satisfies_wolfe(val_0, val_c, f_lim,
                                         sufficient_decrease_param,
                                         curvature_param)

    def _cond(converged, *ignored_args):  # pylint:disable=unused-argument
      """Loops until convergence is reached."""
      return ~converged

    def _update_with_mid(current_evals, left, right):
      """Corresponds to step L3 in [Hager and Zhang (2006)][2]."""
      mid_pt = (left.x + right.x) / 2
      f_mid, df_mid = value_and_gradients_function(mid_pt)
      mid = _FnDFn(x=mid_pt, f=f_mid, df=df_mid)
      update_evals, next_left, next_right = _update(
          value_and_gradients_function,
          left,
          right,
          mid,
          f_lim)
      return (False,
              current_evals + update_evals + 1,
              next_left,
              next_right)

    def _body(converged, evals, left, right):
      converged, secant2_evals, next_left, next_right = _secant2(
          value_and_gradients_function, val_0, left, right, f_lim,
          sufficient_decrease_param=sufficient_decrease_param,
          curvature_param=curvature_param)

      evals += secant2_evals
      # If converged, then do no further processing.
      return smart_cond.smart_case(
          [(converged,
            lambda: (True, evals, next_left, next_right)),
           (next_right.x - next_left.x > shrinkage_param * (right.x - left.x),
            lambda: _update_with_mid(evals, next_left, next_right))],
          default=lambda: (False, evals, next_left, next_right))

    def do_line_search():
      _, bracket_evals, left, right = _bracket(
          value_and_gradients_function,
          val_0,
          val_c,
          f_lim,
          expansion_param=expansion_param)
      return tf.while_loop(_cond,
                           _body,
                           (False,
                            bracket_evals + prepare_evals,
                            left,
                            right),
                           parallel_iterations=1)

    converged, func_evals, left, right = smart_cond.smart_cond(
        already_converged,
        lambda: (already_converged, prepare_evals, val_c, val_c),
        do_line_search)

    return HagerZhangLineSearchResult(
        converged=converged,
        func_evals=func_evals,
        left_pt=left.x,
        objective_at_left_pt=left.f,
        grad_objective_at_left_pt=left.df,
        right_pt=right.x,
        objective_at_right_pt=right.f,
        grad_objective_at_right_pt=right.df
    )
Example #25
0
def hager_zhang(value_and_gradients_function,
                initial_step_size=None,
                objective_at_zero=None,
                grad_objective_at_zero=None,
                objective_at_initial_step_size=None,
                grad_objective_at_initial_step_size=None,
                threshold_use_approximate_wolfe_condition=1e-6,
                shrinkage_param=0.66,
                expansion_param=5.0,
                sufficient_decrease_param=0.1,
                curvature_param=0.9,
                name=None):
    """The Hager Zhang line search algorithm.

  Performs an inexact line search based on the algorithm of
  [Hager and Zhang (2006)][2].
  The univariate objective function `value_and_gradients_function` is typically
  generated by projecting
  a multivariate objective function along a search direction. Suppose the
  multivariate function to be minimized is `g(x1,x2, .. xn)`. Let
  (d1, d2, ..., dn) be the direction along which we wish to perform a line
  search. Then the projected univariate function to be used for line search is

  ```None
    f(a) = g(x1 + d1 * a, x2 + d2 * a, ..., xn + dn * a)
  ```

  The directional derivative along (d1, d2, ..., dn) is needed for this
  procedure. This also corresponds to the derivative of the projected function
  `f(a)` with respect to `a`. Note that this derivative must be negative for
  `a = 0` if the direction is a descent direction.

  The usual stopping criterion for the line search is the satisfaction of the
  (weak) Wolfe conditions. For details of the Wolfe conditions, see
  ref. [3]. On a finite precision machine, the exact Wolfe conditions can
  be difficult to satisfy when one is very close to the minimum and as argued
  by [Hager and Zhang (2005)][1], one can only expect the minimum to be
  determined within square root of machine precision. To improve the situation,
  they propose to replace the Wolfe conditions with an approximate version
  depending on the derivative of the function which is applied only when one
  is very close to the minimum. The following algorithm implements this
  enhanced scheme.

  ### Usage:

  Primary use of line search methods is as an internal component of a class of
  optimization algorithms (called line search based methods as opposed to
  trust region methods). Hence, the end user will typically not want to access
  line search directly. In particular, inexact line search should not be
  confused with a univariate minimization method. The stopping criterion of
  line search is the satisfaction of Wolfe conditions and not the discovery of
  the minimum of the function.

  With this caveat in mind, the following example illustrates the standalone
  usage of the line search.

  ```python
    # Define a quadratic target with minimum at 1.3.
    value_and_gradients_function = lambda x: ((x - 1.3) ** 2, 2 * (x-1.3))
    # Set initial step size.
    step_size = tf.constant(0.1)
    ls_result = tfp.optimizer.linesearch.hager_zhang(
        value_and_gradients_function, initial_step_size=step_size)
    # Evaluate the results.
    with tf.Session() as session:
      results = session.run(ls_result)
      # Ensure convergence.
      assert(results.converged)
      # If the line search converged, the left and the right ends of the
      # bracketing interval are identical.
      assert(results.left_pt == results.right_pt)
      # Print the number of evaluations and the final step size.
      print ("Final Step Size: %f, Evaluation: %d" % (results.left_pt,
                                                      results.func_evals))
  ```

  ### References:
  [1]: William Hager, Hongchao Zhang. A new conjugate gradient method with
    guaranteed descent and an efficient line search. SIAM J. Optim., Vol 16. 1,
    pp. 170-172. 2005.
    https://www.math.lsu.edu/~hozhang/papers/cg_descent.pdf
  [2]: William Hager, Hongchao Zhang. Algorithm 851: CG_DESCENT, a conjugate
    gradient method with guaranteed descent. ACM Transactions on Mathematical
    Software, Vol 32., 1, pp. 113-137. 2006.
    http://users.clas.ufl.edu/hager/papers/CG/cg_compare.pdf
  [3]: Jorge Nocedal, Stephen Wright. Numerical Optimization. Springer Series in
    Operations Research. pp 33-36. 2006

  Args:
    value_and_gradients_function: A Python callable that accepts a real scalar
      tensor and returns a tuple of scalar tensors of real dtype containing
      the value of the function and its derivative at that point.
      In usual optimization application, this function would be generated by
      projecting the multivariate objective function along some specific
      direction. The direction is determined by some other procedure but should
      be a descent direction (i.e. the derivative of the projected univariate
      function must be negative at 0.).
    initial_step_size: (Optional) Scalar positive `Tensor` of real dtype. The
      initial value to try to bracket the minimum. Default is `1.` as a float32.
      Note that this point need not necessarily bracket the minimum for the line
      search to work correctly but the supplied value must be greater than
      0. A good initial value will make the search converge faster.
    objective_at_zero: (Optional) Scalar `Tensor` of real dtype. If supplied,
      the value of the function at `0.`. If not supplied, it will be computed.
    grad_objective_at_zero: (Optional) Scalar `Tensor` of real dtype. If
      supplied, the derivative of the  function at `0.`. If not supplied, it
      will be computed.
    objective_at_initial_step_size: (Optional) Scalar `Tensor` of real dtype.
      If supplied, the value of the function at `initial_step_size`.
      If not supplied, it will be computed.
    grad_objective_at_initial_step_size: (Optional) Scalar `Tensor` of real
      dtype. If supplied, the derivative of the  function at
      `initial_step_size`. If not supplied, it will be computed.
    threshold_use_approximate_wolfe_condition: Scalar positive `Tensor`
      of real dtype. Corresponds to the parameter 'epsilon' in
      [Hager and Zhang (2006)][2]. Used to estimate the
      threshold at which the line search switches to approximate Wolfe
      conditions.
    shrinkage_param: Scalar positive Tensor of real dtype. Must be less than
      `1.`. Corresponds to the parameter `gamma` in
      [Hager and Zhang (2006)][2].
      If the secant**2 step does not shrink the bracketing interval by this
      proportion, a bisection step is performed to reduce the interval width.
    expansion_param: Scalar positive `Tensor` of real dtype. Must be greater
      than `1.`. Used to expand the initial interval in case it does not bracket
      a minimum. Corresponds to `rho` in [Hager and Zhang (2006)][2].
    sufficient_decrease_param: Positive scalar `Tensor` of real dtype.
      Bounded above by the curvature param. Corresponds to `delta` in the
      terminology of [Hager and Zhang (2006)][2].
    curvature_param: Positive scalar `Tensor` of real dtype. Bounded above
      by `1.`. Corresponds to 'sigma' in the terminology of
      [Hager and Zhang (2006)][2].
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'hager_zhang' is used.

  Returns:
    results: A namedtuple containing the following attributes.
      converged: Boolean scalar `Tensor`. Whether a point satisfying
        Wolfe/Approx wolfe was found.
      func_evals: Scalar int32 `Tensor`. Number of function evaluations made.
      left_pt: Scalar `Tensor` of same dtype as `initial_step_size`. The
        left end point of the final bracketing interval. If converged is True,
        it is equal to `right_pt`. Otherwise, it corresponds to the last
        interval computed.
      objective_at_left_pt: Scalar `Tensor` of same dtype as
        `objective_at_initial_step_size`. The function value at the left
        end point. If converged is True, it is equal to `objective_at_right_pt`.
        Otherwise, it corresponds to the last interval computed.
      grad_objective_at_left_pt: Scalar `Tensor` of same dtype as
        `grad_objective_at_initial_step_size`. The derivative of the function
        at the left end point. If converged is True,
        it is equal to `grad_objective_at_right_pt`. Otherwise it
        corresponds to the last interval computed.
      right_pt: Scalar `Tensor` of same dtype as `initial_step_size`.
        The right end point of the final bracketing interval.
        If converged is True, it is equal to 'step'. Otherwise,
        it corresponds to the last interval computed.
      objective_at_right_pt: Scalar `Tensor` of same dtype as
        `objective_at_initial_step_size`.
        The function value at the right end point. If converged is True, it
        is equal to fn_step. Otherwise, it corresponds to the last
        interval computed.
      grad_objective_at_right_pt: Scalar `Tensor` of same dtype as
        `grad_objective_at_initial_step_size`.
        The derivative of the function at the right end point.
        If converged is True, it is equal to the dfn_step.
        Otherwise it corresponds to the last interval computed.
  """
    with tf.name_scope(name, 'hager_zhang', [
            initial_step_size, objective_at_zero, grad_objective_at_zero,
            objective_at_initial_step_size,
            grad_objective_at_initial_step_size,
            threshold_use_approximate_wolfe_condition, shrinkage_param,
            expansion_param, sufficient_decrease_param, curvature_param
    ]):

        # `val_0` / `val_c` are the function data at zero and at the initial
        # step; `prepare_evals` is the evaluation count reported by
        # `_prepare_args` and is folded into the final `func_evals`.
        val_0, val_c, f_lim, prepare_evals = _prepare_args(
            value_and_gradients_function, initial_step_size,
            objective_at_initial_step_size,
            grad_objective_at_initial_step_size, objective_at_zero,
            grad_objective_at_zero, threshold_use_approximate_wolfe_condition)

        # Check if the initial step size already satisfies the Wolfe conditions.
        # If it does, there is no further work.
        already_converged = _satisfies_wolfe(val_0, val_c, f_lim,
                                             sufficient_decrease_param,
                                             curvature_param)

        # NOTE(review): the loop condition depends only on `converged`; there
        # is no iteration cap in this function, so termination relies on
        # `_secant2` eventually reporting convergence.
        def _cond(converged, *ignored_args):  # pylint:disable=unused-argument
            """Loops until convergence is reached."""
            return ~converged

        def _update_with_mid(current_evals, left, right):
            """Corresponds to step L3 in [Hager and Zhang (2006)][2]."""
            # Bisect the current interval and evaluate the function there.
            mid_pt = (left.x + right.x) / 2
            f_mid, df_mid = value_and_gradients_function(mid_pt)
            mid = _FnDFn(x=mid_pt, f=f_mid, df=df_mid)
            update_evals, next_left, next_right = _update(
                value_and_gradients_function, left, right, mid, f_lim)
            # The `+ 1` accounts for the midpoint evaluation made above.
            return (False, current_evals + update_evals + 1, next_left,
                    next_right)

        def _body(converged, evals, left, right):
            """Runs one `_secant2` step, bisecting if the interval shrank too little."""
            converged, secant2_evals, next_left, next_right = _secant2(
                value_and_gradients_function,
                val_0,
                left,
                right,
                f_lim,
                sufficient_decrease_param=sufficient_decrease_param,
                curvature_param=curvature_param)

            evals += secant2_evals
            # If converged, then do no further processing.
            # Otherwise, if the new interval is wider than `shrinkage_param`
            # times the old one, follow up with a bisection step.
            return smart_cond.smart_case(
                [(converged, lambda: (True, evals, next_left, next_right)),
                 (next_right.x - next_left.x > shrinkage_param *
                  (right.x - left.x),
                  lambda: _update_with_mid(evals, next_left, next_right))],
                default=lambda: (False, evals, next_left, next_right))

        def do_line_search():
            """Obtains a starting interval from `_bracket`, then iterates `_body`."""
            _, bracket_evals, left, right = _bracket(
                value_and_gradients_function,
                val_0,
                val_c,
                f_lim,
                expansion_param=expansion_param)
            return tf.while_loop(
                _cond,
                _body, (False, bracket_evals + prepare_evals, left, right),
                parallel_iterations=1)

        # When the initial step already satisfies the conditions, both ends of
        # the reported interval are the initial step `val_c`.
        converged, func_evals, left, right = smart_cond.smart_cond(
            already_converged, lambda:
            (already_converged, prepare_evals, val_c, val_c), do_line_search)

        return HagerZhangLineSearchResult(converged=converged,
                                          func_evals=func_evals,
                                          left_pt=left.x,
                                          objective_at_left_pt=left.f,
                                          grad_objective_at_left_pt=left.df,
                                          right_pt=right.x,
                                          objective_at_right_pt=right.f,
                                          grad_objective_at_right_pt=right.df)
 def testMissingArg2(self):
     """Checks that `smart_cond` raises `TypeError` when given only one branch."""
     with ops.Graph().as_default(), session.Session():
         one = constant_op.constant(1)
         # Only a single branch callable is supplied on purpose.
         with self.assertRaises(TypeError):
             smart_cond.smart_cond(True, lambda: one)
Example #27
0
 def x_permuted():
     """Transposes `x` with the permutation selected by the index order."""
     # Pick the permutation builder based on whether the source index lies
     # before the destination index.
     chosen_perm = smart_cond.smart_cond(source_idx < dest_idx,
                                         move_right_permutation,
                                         move_left_permutation)
     return array_ops.transpose(x, perm=chosen_perm)
Example #28
0
def move_dimension(x, source_idx, dest_idx):
    """Relocate one dimension of a tensor to a new position in its shape.

    This is a restricted form of a transpose: exactly one dimension changes
    position and every other dimension keeps its relative order.

    Args:
      x: Tensor of rank `ndims`.
      source_idx: Integer index into `x.shape` of the dimension to move
        (negative values count from the end).
      dest_idx: Integer index into `x.shape` where the dimension should end
        up (negative values count from the end).

    Returns:
      x_perm: Tensor of rank `ndims` in which the dimension originally at
        `source_idx` sits at `dest_idx`, all other dimensions retaining
        their original order.

    Example:

    ```python
    x = tf.compat.v1.placeholder(tf.float32, shape=[200, 30, 4, 1, 6])
    x_perm = move_dimension(x, 1, 1)   # no-op
    x_perm = move_dimension(x, 0, 3)   # result shape [30, 4, 1, 200, 6]
    x_perm = move_dimension(x, 0, -2)  # equivalent to previous
    x_perm = move_dimension(x, 4, 2)   # result shape [200, 30, 6, 4, 1]
    ```
    """
    rank = util.prefer_static_rank(x)
    index_dtype = (dtypes.int32 if isinstance(source_idx, int)
                   else dtypes.as_dtype(source_idx.dtype))

    # Convert negative indices to their non-negative equivalents. Because the
    # rank may only be known dynamically, the indices may become dynamic too.
    if source_idx < 0:
        source_idx = rank + source_idx
    if dest_idx < 0:
        dest_idx = rank + dest_idx

    def _perm_toward_front():
        # Permutation for a source that sits after the destination.
        leading = math_ops.range(0, dest_idx, dtype=index_dtype)
        displaced = math_ops.range(dest_idx, source_idx, dtype=index_dtype)
        trailing = math_ops.range(source_idx + 1, rank, dtype=index_dtype)
        pieces = [leading, [source_idx], displaced, trailing]
        return util.prefer_static_value(array_ops.concat(pieces, axis=0))

    def _perm_toward_back():
        # Permutation for a source that sits before the destination.
        leading = math_ops.range(0, source_idx, dtype=index_dtype)
        displaced = math_ops.range(source_idx + 1, dest_idx + 1,
                                   dtype=index_dtype)
        trailing = math_ops.range(dest_idx + 1, rank, dtype=index_dtype)
        pieces = [leading, displaced, [source_idx], trailing]
        return util.prefer_static_value(array_ops.concat(pieces, axis=0))

    def _transposed():
        chosen_perm = smart_cond.smart_cond(source_idx < dest_idx,
                                            _perm_toward_back,
                                            _perm_toward_front)
        return array_ops.transpose(x, perm=chosen_perm)

    def _unchanged():
        return x

    # Equal source and destination indices mean there is nothing to move.
    same_position = math_ops.equal(source_idx, dest_idx)
    return smart_cond.smart_cond(same_position, _unchanged, _transposed)
Example #29
0
def move_dimension(x, source_idx, dest_idx):
  """Relocate one dimension of a tensor to a new position in its shape.

  This is a restricted form of `tf.transpose()`: exactly one dimension
  changes position and every other dimension keeps its relative order.

  Args:
    x: Tensor of rank `ndims`.
    source_idx: Integer index into `x.shape` of the dimension to move
      (negative values count from the end).
    dest_idx: Integer index into `x.shape` where the dimension should end up
      (negative values count from the end).

  Returns:
    x_perm: Tensor of rank `ndims` in which the dimension originally at
      `source_idx` sits at `dest_idx`, all other dimensions retaining their
      original order.

  Example:

  ```python
  x = tf.placeholder(tf.float32, shape=[200, 30, 4, 1, 6])
  x_perm = move_dimension(x, 1, 1)   # no-op
  x_perm = move_dimension(x, 0, 3)   # result shape [30, 4, 1, 200, 6]
  x_perm = move_dimension(x, 0, -2)  # equivalent to previous
  x_perm = move_dimension(x, 4, 2)   # result shape [200, 30, 6, 4, 1]
  ```
  """
  rank = util.prefer_static_rank(x)
  index_dtype = (tf.int32 if isinstance(source_idx, int)
                 else tf.as_dtype(source_idx.dtype))

  # Convert negative indices to their non-negative equivalents. Because the
  # rank may only be known dynamically, the indices may become dynamic too.
  if source_idx < 0:
    source_idx = rank + source_idx
  if dest_idx < 0:
    dest_idx = rank + dest_idx

  def _perm_toward_front():
    # Permutation for a source that sits after the destination.
    leading = tf.range(0, dest_idx, dtype=index_dtype)
    displaced = tf.range(dest_idx, source_idx, dtype=index_dtype)
    trailing = tf.range(source_idx + 1, rank, dtype=index_dtype)
    pieces = [leading, [source_idx], displaced, trailing]
    return util.prefer_static_value(tf.concat(pieces, axis=0))

  def _perm_toward_back():
    # Permutation for a source that sits before the destination.
    leading = tf.range(0, source_idx, dtype=index_dtype)
    displaced = tf.range(source_idx + 1, dest_idx + 1, dtype=index_dtype)
    trailing = tf.range(dest_idx + 1, rank, dtype=index_dtype)
    pieces = [leading, displaced, [source_idx], trailing]
    return util.prefer_static_value(tf.concat(pieces, axis=0))

  def _transposed():
    chosen_perm = smart_cond.smart_cond(source_idx < dest_idx,
                                        _perm_toward_back,
                                        _perm_toward_front)
    return tf.transpose(x, perm=chosen_perm)

  def _unchanged():
    return x

  # Equal source and destination indices mean there is nothing to move.
  same_position = tf.equal(source_idx, dest_idx)
  return smart_cond.smart_cond(same_position, _unchanged, _transposed)
Example #30
0
def fit_one_step(model_matrix,
                 response,
                 model,
                 model_coefficients_start=None,
                 predicted_linear_response_start=None,
                 l2_regularizer=None,
                 dispersion=None,
                 offset=None,
                 learning_rate=None,
                 fast_unsafe_numerics=True,
                 name=None):
    """Performs a single Fisher scoring update.

    The update is carried out as one iteration of iteratively reweighted
    least squares: an adjusted response and per-sample weights are derived
    from `model`, and the resulting weighted least-squares problem is solved
    for the next coefficients.

    Args:
      model_matrix: (Batch of) `float`-like, matrix-shaped `Tensor` where each
        row represents a sample's features.
      response: (Batch of) vector-shaped `Tensor` where each element
        represents a sample's observed response (to the corresponding row of
        features). Must have same `dtype` as `model_matrix`.
      model: `tfp.glm.ExponentialFamily`-like instance used to construct the
        negative log-likelihood loss, gradient, and expected Hessian (i.e.,
        the Fisher information matrix). It is called on the linear response
        and must return `(mean, variance, grad_mean)`.
      model_coefficients_start: Optional (batch of) vector-shaped `Tensor`
        representing the initial model coefficients, one for each column in
        `model_matrix`. Must have same `dtype` as `model_matrix`.
        Default value: Zeros.
      predicted_linear_response_start: Optional `Tensor` with `shape`, `dtype`
        matching `response`; represents `offset` shifted initial linear
        predictions based on `model_coefficients_start`.
        Default value: `offset` if `model_coefficients_start is None`, and
        `tfp.math.matvecmul(model_matrix, model_coefficients_start) + offset`
        otherwise.
      l2_regularizer: Optional scalar `Tensor` representing L2 regularization
        penalty, i.e.,
        `loss(w) = sum{-log p(y[i]|x[i],w) : i=1..n} + l2_regularizer ||w||_2^2`.
        Default value: `None` (i.e., no L2 regularization).
      dispersion: Optional (batch of) `Tensor` representing `response`
        dispersion, i.e., as in,
        `p(y|theta) := exp((y theta - A(theta)) / dispersion)`.
        Must broadcast with rows of `model_matrix`.
        Default value: `None` (i.e., "no dispersion").
      offset: Optional `Tensor` representing constant shift applied to
        `predicted_linear_response`.  Must broadcast to `response`.
        Default value: `None` (i.e., `tf.zeros_like(response)`).
      learning_rate: Optional (batch of) scalar `Tensor` used to dampen
        iterative progress. Typically only needed if optimization diverges,
        should be no larger than `1` and typically very close to `1`.
        Default value: `None` (i.e., `1`).
      fast_unsafe_numerics: Optional Python `bool` indicating if solve should
        be based on Cholesky or QR decomposition.
        Default value: `True` (i.e., "prefer speed via Cholesky
        decomposition").
      name: Python `str` used as name prefix to ops created by this function.
        Default value: `"fit_one_step"`.

    Returns:
      model_coefficients: (Batch of) vector-shaped `Tensor`; represents the
        next estimate of the model coefficients, one for each column in
        `model_matrix`.
      predicted_linear_response: `response`-shaped `Tensor` representing
        linear predictions based on new `model_coefficients`, i.e.,
        `tfp.math.matvecmul(model_matrix, model_coefficients_next) + offset`.
    """
    scope_values = [
        model_matrix, response, model_coefficients_start,
        predicted_linear_response_start, dispersion, learning_rate
    ]
    with tf.name_scope(name, 'fit_one_step', scope_values):

        (model_matrix,
         response,
         model_coefficients_start,
         predicted_linear_response_start,
         offset) = prepare_args(model_matrix, response,
                                model_coefficients_start,
                                predicted_linear_response_start, offset)

        # The model supplies the mean, its variance, and the derivative of the
        # mean with respect to the linear response.
        mean, variance, grad_mean = model(predicted_linear_response_start)

        # Entries whose `grad_mean` or `variance` is non-finite (or whose
        # `grad_mean` is zero / `variance` is non-positive) get substituted
        # below so that the corresponding row is zeroed out. Substituting via
        # `tf.where` rather than dropping rows keeps this algorithm itself
        # differentiable.
        grad_mean_ok = tf.is_finite(grad_mean) & tf.not_equal(grad_mean, 0.)
        is_valid = grad_mean_ok & tf.is_finite(variance) & (variance > 0.)

        def _replace_invalid(values, fill_value):
            # Keep `values` where the sample is valid, else use `fill_value`.
            filler = tf.fill(
                tf.shape(values),
                value=np.array(fill_value, values.dtype.as_numpy_dtype))
            return tf.where(is_valid, values, filler)

        # Adjusted predicted linear response ("z" in IRLS terminology):
        #   z = predicted_linear_response_start
        #       + learning_rate * (response - mean) / grad_mean
        adjusted_response = (response - mean) / _replace_invalid(grad_mean, 1.)
        # TODO(jvdillon): Rather than use learning rate, we should consider
        # using backtracking line search.
        if learning_rate is not None:
            adjusted_response *= learning_rate[..., tf.newaxis]
        adjusted_response += predicted_linear_response_start

        # Per-sample weight ("w" in IRLS terminology). The variance is first
        # scaled by the dispersion factor, when one is given.
        if dispersion is not None:
            variance *= dispersion
        safe_grad_mean = _replace_invalid(grad_mean, 0.)
        inverse_std = tf.rsqrt(_replace_invalid(variance, np.inf))
        weights = safe_grad_mean * inverse_std

        # Least-squares system to solve:
        #   min{ || A @ model_coefficients - b ||_2**2 : model_coefficients }
        # where `@` denotes `matmul`.
        lhs = model_matrix * weights[..., tf.newaxis]
        rhs = adjusted_response * weights
        numpy_dtype = lhs.dtype.as_numpy_dtype

        if l2_regularizer is not None:
            # Prefer the statically known value when one is available.
            static_l2 = distributions_util.maybe_get_static_value(
                l2_regularizer, numpy_dtype)
            if static_l2 is not None:
                l2_regularizer = static_l2
        else:
            l2_regularizer = np.array(0, numpy_dtype)

        def _embed_l2_regularization():
            """Adds synthetic observations to implement L2 regularization."""
            # `tf.matrix_solve_ls` does not respect the `l2_regularization`
            # argument when `fast_unsafe_numerics` is `False`, so the penalty
            # is expressed through extra rows instead. Appending the rows
            # `sqrt(l2_regularizer) * I` is mathematically equivalent to
            # adding `-l2_regularizer ||coefficients||_2**2` to the
            # log-likelihood.
            num_model_coefficients = num_cols(model_matrix)
            batch_shape = tf.shape(model_matrix)[:-2]
            identity = tf.eye(num_model_coefficients,
                              batch_shape=batch_shape,
                              dtype=lhs.dtype)
            padded_lhs = tf.concat(
                [lhs, tf.sqrt(l2_regularizer) * identity], axis=-2)
            padded_rhs = distributions_util.pad(rhs,
                                                count=num_model_coefficients,
                                                axis=-1,
                                                back=True)
            # The penalty is now embedded, so the solver gets a zero penalty.
            return padded_lhs, padded_rhs, np.array(0, numpy_dtype)

        def _keep_system():
            return lhs, rhs, l2_regularizer

        needs_embedding = smart_reduce_all(
            [not fast_unsafe_numerics, l2_regularizer > 0.])
        solve_lhs, solve_rhs, l2_regularizer = smart_cond.smart_cond(
            needs_embedding, _embed_l2_regularization, _keep_system)

        solution = tf.matrix_solve_ls(
            solve_lhs,
            solve_rhs[..., tf.newaxis],
            fast=fast_unsafe_numerics,
            l2_regularizer=l2_regularizer,
            name='model_coefficients_next')
        # Drop the trailing axis that was added for the solver.
        model_coefficients_next = solution[..., 0]

        # TODO(b/79122261): The approach used in `matrix_solve_ls` could be
        # made faster by avoiding explicitly forming Q and instead keeping the
        # factorization in 'implicit' form with stacked (rescaled) Householder
        # vectors underneath the 'R' and then applying the (accumulated)
        # reflectors in the appropriate order to apply Q'. However, we don't
        # presently do this because we lack core TF functionality. For
        # reference, the vanilla QR approach is:
        #   q, r = tf.linalg.qr(a)
        #   c = tf.matmul(q, b, adjoint_a=True)
        #   model_coefficients_next = tf.matrix_triangular_solve(
        #       r, c, lower=False, name='model_coefficients_next')

        predicted_linear_response_next = calculate_linear_predictor(
            model_matrix,
            model_coefficients_next,
            offset,
            name='predicted_linear_response_next')

        return model_coefficients_next, predicted_linear_response_next
 def testSmartCondMissingArg2(self):
   """Checks that `smart_cond` raises `TypeError` when given only one branch."""
   with ops.Graph().as_default(), session.Session():
     one = constant_op.constant(1)
     # Only a single branch callable is supplied on purpose.
     with self.assertRaises(TypeError):
       smart_cond.smart_cond(True, lambda: one)
Example #32
0
        def _loop_body(
                iter_,
                x_update_diff_norm_sq,
                x_update,  # pylint: disable=missing-docstring
                hess_matmul_x_update):
            """Performs one coordinate update of `x_update` and advances the counter.

            Args:
              iter_: Loop counter; the coordinate handled in this call is
                `iter_ % dims`.
              x_update_diff_norm_sq: Running sum of squared coordinate changes
                for the current sweep. It is reset to zero whenever the
                coordinate index is 0.
              x_update: Current update vector, indexed by coordinate.
              hess_matmul_x_update: Running vector to which
                `delta * hessian_column_with_l2` is added on every non-zero
                update.

            Returns:
              A list `[iter_ + 1, x_update_diff_norm_sq, x_update,
              hess_matmul_x_update]`. The last three entries are passed
              through unchanged when `delta` equals zero and come from
              `_do_update` otherwise.
            """
            # Inner loop of GLMNet's minimizer.
            #
            # This loop updates a single coordinate of x_update.  Ideally, an
            # iteration of this loop would set
            #
            #   x_update[j] += argmin{ LocalLoss(x_update + z*e_j) : z in R }
            #
            # where
            #
            #   LocalLoss(x_update')
            #     = LocalLossSmoothComponent(x_update')
            #         + l1_regularizer * (||x_start + x_update'||_1 -
            #                             ||x_start + x_update||_1)
            #    := (UnregularizedLoss(x_start + x_update') -
            #        UnregularizedLoss(x_start + x_update)
            #         + l2_regularizer * (||x_start + x_update'||_2**2 -
            #                             ||x_start + x_update||_2**2)
            #         + l1_regularizer * (||x_start + x_update'||_1 -
            #                             ||x_start + x_update||_1)
            #
            # In this algorithm approximate the above argmin using (univariate)
            # proximal gradient descent:
            #
            # (*)  x_update[j] = prox_{t * l1_regularizer * L1}(
            #                 x_update[j] -
            #                 t * d/dz|z=0 UnivariateLocalLossSmoothComponent(z))
            #
            # where
            #
            #   UnivariateLocalLossSmoothComponent(z)
            #       := LocalLossSmoothComponent(x_update + z*e_j)
            #
            # and we approximate
            #
            #       d/dz UnivariateLocalLossSmoothComponent(z)
            #     = grad LocalLossSmoothComponent(x_update))[j]
            #    ~= (grad LossSmoothComponent(x_start)
            #         + x_update matmul HessianOfLossSmoothComponent(x_start))[j].
            #
            # To choose the parameter t, we squint and pretend that the inner term of
            # (*) is a Newton update as if we were using Newton's method to minimize
            # UnivariateLocalLossSmoothComponent.  That is, we choose t such that
            #
            #   -t * d/dz ULLSC = -learning_rate * (d/dz ULLSC) / (d^2/dz^2 ULLSC)
            #
            # at z=0.  Hence
            #
            #   t = learning_rate / (d^2/dz^2|z=0 ULLSC)
            #     = learning_rate / HessianOfLossSmoothComponent(
            #                           x_start + x_update)[j,j]
            #    ~= learning_rate / HessianOfLossSmoothComponent(
            #                           x_start)[j,j]
            #
            # The above approximation is equivalent to assuming that
            # HessianOfUnregularizedLoss is constant, i.e., ignoring third-order
            # effects.
            #
            # Note that because LossSmoothComponent is (assumed to be) convex, t is
            # positive.

            # In above notation, coord = j.
            coord = iter_ % dims
            # x_update_diff_norm_sq := ||x_update_end - x_update_start||_2**2,
            # computed incrementally, where x_update_end and x_update_start are as
            # defined in the convergence criteria.  Accordingly, we reset
            # x_update_diff_norm_sq to zero at the beginning of each sweep.
            x_update_diff_norm_sq = tf.where(
                tf.equal(coord, 0), tf.zeros_like(x_update_diff_norm_sq),
                x_update_diff_norm_sq)

            w_old = x_start[coord] + x_update[coord]
            # This is the coordinatewise Newton update if no L1 regularization.
            # In above notation, newton_step = -t * (approximation of d/dz|z=0 ULLSC).
            second_deriv = _hessian_diag_elt_with_l2(coord)
            newton_step = -_mul_ignoring_nones(  # pylint: disable=invalid-unary-operand-type
                learning_rate, grad_loss_with_l2[coord] +
                hess_matmul_x_update[coord]) / second_deriv
            # Applying the soft-threshold operator accounts for L1 regularization.
            # In above notation, delta =
            #     prox_{t*l1_regularizer*L1}(w_old + newton_step) - w_old.
            delta = (soft_threshold(
                w_old + newton_step,
                _mul_ignoring_nones(learning_rate, l1_regularizer) /
                second_deriv) - w_old)

            def _do_update(x_update_diff_norm_sq, x_update,
                           hess_matmul_x_update):  # pylint: disable=missing-docstring
                # The x_update argument is discarded; the new value is written
                # to the enclosing `x_update_var` via scatter_update below.
                del x_update
                # Column `coord` of the Hessian, built from the outer/middle
                # factors applied to a one-hot selector for `coord`.
                hessian_column_with_l2 = _sparse_or_dense_matvecmul(
                    hessian_unregularized_loss_outer,
                    hessian_unregularized_loss_middle *
                    _sparse_or_dense_matmul_onehot(
                        hessian_unregularized_loss_outer, coord, num_samples),
                    adjoint_a=True)
                if l2_regularizer is not None:
                    # Add the L2 contribution 2 * l2_regularizer at index `coord`.
                    hessian_column_with_l2 += _one_hot_like(
                        hessian_column_with_l2,
                        coord,
                        on_value=2. * l2_regularizer)
                changed_x_update_var = tf.scatter_update(
                    x_update_var, [coord], [x_update_var[coord] + delta])
                # Build the remaining outputs only after the scatter update op.
                with tf.control_dependencies([changed_x_update_var]):
                    x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2.
                    hess_matmul_x_update_ = (hess_matmul_x_update +
                                             delta * hessian_column_with_l2)
                    return [
                        x_update_diff_norm_sq_, changed_x_update_var,
                        hess_matmul_x_update_
                    ]

            inputs_to_update = [
                x_update_diff_norm_sq, x_update, hess_matmul_x_update
            ]
            return [iter_ + 1] + smart_cond.smart_cond(
                # Note on why checking delta (a difference of floats) for equality to
                # zero is ok:
                #
                # First of all, x - x == 0 in floating point -- see
                # https://stackoverflow.com/a/2686671
                #
                # Delta will conceptually equal zero when one of the following holds:
                # (i)   |w_old + newton_step| <= threshold and w_old == 0
                # (ii)  |w_old + newton_step| > threshold and
                #       w_old + newton_step - sign(w_old + newton_step) * threshold
                #          == w_old
                #
                # In case (i) comparing delta to zero is fine.
                #
                # In case (ii), newton_step conceptually equals
                #     sign(w_old + newton_step) * threshold.
                # Also remember
                #     threshold = -newton_step / (approximation of d/dz|z=0 ULLSC).
                # NOTE(review): the two mentions of "(i)" below appear to refer to
                # case (ii), which is the case being analysed here -- confirm.
                # So (i) happens when
                #     (approximation of d/dz|z=0 ULLSC) == -sign(w_old + newton_step).
                # If we did not require LossSmoothComponent to be strictly convex,
                # then this could actually happen a non-negligible amount of the time,
                # e.g. if the loss function is piecewise linear and one of the pieces
                # has slope 1.  But since LossSmoothComponent is strictly convex, (i)
                # should not systematically happen.
                tf.equal(delta, 0.),
                lambda: inputs_to_update,
                lambda: _do_update(*inputs_to_update))
Example #33
0
        def _body(
                converged,  # pylint: disable=unused-argument
                stopped,  # pylint: disable=unused-argument
                iteration,
                total_evals,
                position,
                objective_value,
                objective_gradient,
                input_inv_hessian_estimate):
            """Runs a single iteration of the optimization loop.

            A search direction is derived from the incoming inverse Hessian
            estimate. When the directional derivative along it is not negative,
            the direction is rebuilt from `initial_inv_hessian` instead. A
            Hager-Zhang line search is then run along the chosen direction.

            Args:
              converged: Unused.
              stopped: Unused.
              iteration: Current iteration count.
              total_evals: Number of objective evaluations so far.
              position: Current position.
              objective_value: Objective value at `position`.
              objective_gradient: Objective gradient at `position`.
              input_inv_hessian_estimate: Inverse Hessian estimate carried in
                from the previous iteration.

            Returns:
              A `BfgsOptimizerResults` marked as failed when the line search
              did not converge, otherwise the result of `_bfgs_update`.
            """
            proposed_direction = _get_search_direction(
                input_inv_hessian_estimate, objective_gradient)
            proposed_slope = tf.reduce_sum(objective_gradient *
                                           proposed_direction)

            def _restart_from_initial_hessian():
                """Rebuilds the direction from the initial inverse Hessian."""
                restarted_direction = _get_search_direction(
                    initial_inv_hessian, objective_gradient)
                return restarted_direction, initial_inv_hessian

            def _keep_proposed_direction():
                """Keeps the direction and estimate that were passed in."""
                return proposed_direction, input_inv_hessian_estimate

            # A non-negative slope means the proposed direction is not a
            # descent direction, so fall back to the initial estimate.
            search_direction, inv_hessian_estimate = smart_cond.smart_cond(
                proposed_slope >= 0,
                true_fn=_restart_from_initial_hessian,
                false_fn=_keep_proposed_direction)

            restricted_value_and_grad = _restrict_along_direction(
                value_and_gradients_function, position, search_direction)
            # Recompute the slope, since the direction may have been replaced.
            slope_at_start = tf.reduce_sum(objective_gradient *
                                           search_direction)

            ls_result = linesearch.hager_zhang(
                restricted_value_and_grad,
                initial_step_size=tf.convert_to_tensor(1, dtype=dtype),
                objective_at_zero=objective_value,
                grad_objective_at_zero=slope_at_start)

            line_search_failed = ~ls_result.converged

            def _on_line_search_failure():
                """Reports failure while leaving the optimizer state in place."""
                return BfgsOptimizerResults(
                    converged=False,
                    failed=True,
                    num_iterations=iteration + 1,
                    num_objective_evaluations=total_evals +
                    ls_result.func_evals,
                    position=position,
                    objective_value=objective_value,
                    objective_gradient=objective_gradient,
                    inverse_hessian_estimate=inv_hessian_estimate)

            def _on_line_search_success():
                """Applies the BFGS update using the line search result."""
                return _bfgs_update(
                    value_and_gradients_function,
                    position,
                    objective_value,
                    objective_gradient,
                    search_direction,
                    inv_hessian_estimate,
                    ls_result.left_pt,
                    iteration,
                    total_evals + ls_result.func_evals,
                    tolerance,
                    f_relative_tolerance,
                    x_tolerance)

            return smart_cond.smart_cond(
                line_search_failed,
                true_fn=_on_line_search_failure,
                false_fn=_on_line_search_success)
def write(tag, tensor, step=None, metadata=None, name=None):
    """Writes a generic summary to the default SummaryWriter if one exists.

    This is the building block for type-specific summary ops such as scalar()
    and image(); it is not meant to be called directly except when defining a
    new type-specific summary op.

    Args:
      tag: string tag used to identify the summary (e.g. in TensorBoard),
        usually generated with `tf.summary.summary_scope`.
      tensor: the Tensor holding the summary data to write, or a callable that
        returns this Tensor. A callable is only invoked when a default
        SummaryWriter exists and the recording condition specified by
        `record_if()` is met.
      step: Explicit `int64`-castable monotonic step value for this summary.
        If omitted, this defaults to `tf.summary.experimental.get_step()`,
        which must not be None.
      metadata: Optional SummaryMetadata, as a proto or serialized bytes.
      name: Optional string name for this op.

    Returns:
      True on success, or false if no summary was written because no default
      summary writer was available.

    Raises:
      ValueError: if a default writer exists, but no step was provided and
        `tf.summary.experimental.get_step()` is None.
    """
    with ops.name_scope(name, "write_summary") as scope:
        # Without a default writer there is nothing to do.
        if _summary_state.writer is None:
            return constant_op.constant(False)

        # Fall back to the globally configured step when none was passed.
        if step is None:
            step = get_step()
        if step is None:
            raise ValueError("No step set via 'step' argument or "
                             "tf.summary.experimental.set_step()")

        # Accept a proto-like object, raw serialized bytes, or nothing.
        metadata_bytes = metadata
        if metadata is None:
            metadata_bytes = b""
        elif hasattr(metadata, "SerializeToString"):
            metadata_bytes = metadata.SerializeToString()

        def _emit_summary():
            """Issues the write op and returns True once it has run."""
            # The identity op is what moves the tensor onto the CPU.
            with ops.device("cpu:0"):
                if callable(tensor):
                    payload = tensor()
                else:
                    payload = array_ops.identity(tensor)
                writer_resource = _summary_state.writer._resource  # pylint: disable=protected-access
                write_op = gen_summary_ops.write_summary(
                    writer_resource,
                    step,
                    payload,
                    tag,
                    metadata_bytes,
                    name=scope)
                with ops.control_dependencies([write_op]):
                    return constant_op.constant(True)

        result = smart_cond.smart_cond(
            _should_record_summaries_v2(),
            _emit_summary,
            _nothing,
            name="summary_cond")
        if not context.executing_eagerly():
            ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, result)  # pylint: disable=protected-access
        return result
Example #35
0
 def x_permuted():
   """Transposes `x` with the permutation chosen by the move direction."""
   # Pick the permutation builder by comparing source and destination index.
   permutation = smart_cond.smart_cond(
       source_idx < dest_idx,
       move_right_permutation,
       move_left_permutation)
   return tf.transpose(x, perm=permutation)
Example #36
0
def _bfgs_inv_hessian_update(grad_delta, position_delta, inv_hessian_estimate):
    """Applies the BFGS update to the inverse Hessian estimate.

    With `A^T` denoting the transpose of `A`, the BFGS update rule is

    ```None
      rho = 1/(grad_delta^T * position_delta)
      U = (I - rho * position_delta * grad_delta^T)
      H_1 =  U * H_0 * U^T + rho * position_delta * position_delta^T
    ```

    where `H_0` is the previous inverse Hessian estimate, `H_1` is the next
    one and `*` is matrix multiplication (ordinary multiplication for scalars,
    the action of the matrix on the vector for matrix-vector products).

    The code uses the expanded form of that formula, which needs only
    matrix-vector and vector-vector operations:

    ```None
      f = 1 + rho * (grad_delta^T * H_0 * grad_delta)
      H_1 - H_0 = - rho * [position_delta * (H_0 * grad_delta)^T +
                          (H_0 * grad_delta) * position_delta^T] +
                    rho * f * [position_delta * position_delta^T]
    ```

    The bracketed terms are matrices built from vector outer products; every
    other term on the right hand side is a scalar.

    Args:
      grad_delta: `Tensor` of real dtype and same shape as `position_delta`.
        The difference between the gradient at the new position and the old
        position.
      position_delta: `Tensor` of real dtype and nonzero rank. The change in
        position from the previous iteration to the current one.
      inv_hessian_estimate: `Tensor` of real dtype and shape equal to
        the shape of `position_delta` concatenated with itself. If the shape
        of `position_delta` is [n1, n2,.., nr] then the shape of
        `inv_hessian_estimate` should be `[n1, n2, ..., nr, n1, n2, ..., nr]`.
        The previous estimate of the inverse Hessian. Should be positive
        definite and symmetric.

    Returns:
      The next inverse Hessian estimate. When `grad_delta^T * position_delta`
      equals zero the update is singular and `inv_hessian_estimate` is
      returned unchanged; otherwise the BFGS-updated estimate is returned.
    """
    # The normalization term y^T . s, i.e. 1 / rho in the docstring notation.
    curvature = tf.reduce_sum(grad_delta * position_delta)
    is_singular = tf.equal(curvature, 0)

    def _keep_previous_estimate():
        """Singular update: hand back the estimate that was passed in."""
        return inv_hessian_estimate

    def _apply_bfgs_update():
        """Computes the updated inverse Hessian estimate."""
        rank = len(grad_delta.shape.as_list())
        trailing_axes = np.arange(-rank, 0)

        # H.y: contract the last `rank` axes of the estimate with the
        # gradient change.
        hessian_times_grad_delta = tf.tensordot(
            inv_hessian_estimate,
            grad_delta,
            axes=[trailing_axes, trailing_axes])
        # The quadratic form y^T.H.y.
        quadratic_form = tf.reduce_sum(hessian_times_grad_delta * grad_delta)

        # Symmetrized cross term: s (H.y)^T + (H.y) s^T.
        cross_term = (
            _tensor_product(position_delta, hessian_times_grad_delta) +
            _tensor_product(hessian_times_grad_delta, position_delta))

        # s s^T scaled by f = 1 + y^T.H.y / (y^T.s).
        position_term = _tensor_product(position_delta, position_delta)
        with tf.control_dependencies([position_term]):
            position_term = position_term * (1 + quadratic_form / curvature)

        correction = (position_term - cross_term) / curvature
        return inv_hessian_estimate + correction

    return smart_cond.smart_cond(
        is_singular,
        true_fn=_keep_previous_estimate,
        false_fn=_apply_bfgs_update)
Example #37
0
    def call(self, inputs, training=None):
        """Runs the convolution with batchnorm statistics folded into its weights.

        When the resolved `training` value is exactly `True`, the folded kernel
        and bias are recomputed from the batchnorm parameters according to
        `self.folding_mode`, optionally quantized, and stored in
        `self.folded_kernel_quantized` / `self.folded_bias_quantized`.
        Otherwise the stored folded weights are used as they are.

        Args:
          inputs: Input tensor for the convolution.
          training: Training flag. It is resolved through
            `self.batchnorm._get_training_value`.

        Returns:
          The convolution output computed with the folded kernel and bias,
          passed through `self.activation` when one is set.

        Raises:
          ValueError: If training and `self.folding_mode` is neither
            "batch_stats_folding" nor "ema_stats_folding".
        """

        # numpy value, mark the layer is in training
        training = self.batchnorm._get_training_value(training)  # pylint: disable=protected-access

        # Batchnorm statistics are only updated while training and while the
        # iteration counter has not passed ema_freeze_delay.
        bn_training = tf.math.logical_and(
            training, tf.math.less_equal(self._iteration,
                                         self.ema_freeze_delay))

        kernel = self.kernel

        # run conv to produce output for the following batchnorm
        conv_outputs = tf.keras.backend.conv2d(
            inputs,
            kernel,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)

        if self.use_bias:
            bias = self.bias
            conv_outputs = tf.keras.backend.bias_add(
                conv_outputs, bias, data_format=self.data_format)
        else:
            bias = 0

        # The output is discarded; the layer is invoked for its side effects
        # (presumably updating the moving statistics -- TODO confirm).
        _ = self.batchnorm(conv_outputs, training=bn_training)
        if training is True:
            # The following operation is only performed during training

            self._iteration.assign_add(
                tf_utils.smart_cond(training, lambda: tf.constant(1, tf.int64),
                                    lambda: tf.constant(0, tf.int64)))

            # calculate mean and variance from current batch
            bn_shape = conv_outputs.shape
            ndims = len(bn_shape)
            reduction_axes = [
                i for i in range(ndims) if i not in self.batchnorm.axis
            ]
            keep_dims = len(self.batchnorm.axis) > 1
            mean, variance = self.batchnorm._moments(  # pylint: disable=protected-access
                math_ops.cast(conv_outputs, self.batchnorm._param_dtype),  # pylint: disable=protected-access
                reduction_axes,
                keep_dims=keep_dims)
            # get batchnorm weights
            gamma = self.batchnorm.gamma
            beta = self.batchnorm.beta
            moving_mean = self.batchnorm.moving_mean
            moving_variance = self.batchnorm.moving_variance

            if self.folding_mode == "batch_stats_folding":
                # using batch mean and variance in the initial training stage
                # after sufficient training, switch to moving mean and variance
                new_mean = tf_utils.smart_cond(bn_training, lambda: mean,
                                               lambda: moving_mean)
                new_variance = tf_utils.smart_cond(bn_training,
                                                   lambda: variance,
                                                   lambda: moving_variance)

                # get the inversion factor so that we replace division by multiplication
                inv = math_ops.rsqrt(new_variance + self.batchnorm.epsilon)
                if gamma is not None:
                    inv *= gamma
                # fold bias with bn stats
                folded_bias = inv * (bias - new_mean) + beta

            elif self.folding_mode == "ema_stats_folding":
                # We always scale the weights with a correction factor to the long term
                # statistics prior to quantization. This ensures that there is no jitter
                # in the quantized weights due to batch to batch variation. During the
                # initial phase of training, we undo the scaling of the weights so that
                # outputs are identical to regular batch normalization. We also modify
                # the bias terms correspondingly. After sufficient training, switch from
                # using batch statistics to long term moving averages for batch
                # normalization.

                # use batch stats for calculating bias before bn freeze, and use moving
                # stats after bn freeze
                mv_inv = math_ops.rsqrt(moving_variance +
                                        self.batchnorm.epsilon)
                batch_inv = math_ops.rsqrt(variance + self.batchnorm.epsilon)

                if gamma is not None:
                    mv_inv *= gamma
                    batch_inv *= gamma
                folded_bias = tf_utils.smart_cond(
                    bn_training, lambda: batch_inv * (bias - mean) + beta,
                    lambda: mv_inv * (bias - moving_mean) + beta)
                # moving stats is always used to fold kernel in tflite; before bn freeze
                # an additional correction factor will be applied to the conv2d output
                inv = mv_inv
            else:
                # `assert ValueError` never fails because the class object is
                # truthy; raise so that an unsupported mode is reported here
                # instead of as an unbound `inv` further down.
                raise ValueError(
                    "Unsupported folding_mode: %r (expected "
                    "'batch_stats_folding' or 'ema_stats_folding')" %
                    (self.folding_mode,))

            # wrap conv kernel with bn parameters
            folded_kernel = inv * kernel
            # quantize the folded kernel
            if self.kernel_quantizer is not None:
                q_folded_kernel = self.kernel_quantizer_internal(folded_kernel)
            else:
                q_folded_kernel = folded_kernel

            # If loaded from a ckpt, bias_quantizer is the ckpt value
            # Else if bias_quantizer not specified, bias
            #   quantizer is None and we need to calculate bias quantizer
            #   type according to accumulator type. User can call
            #   bn_folding_utils.populate_bias_quantizer_from_accumulator(
            #      model, input_quantizer_list]) to populate such bias quantizer.
            if self.bias_quantizer_internal is not None:
                q_folded_bias = self.bias_quantizer_internal(folded_bias)
            else:
                q_folded_bias = folded_bias

            # set value for the folded weights
            self.folded_kernel_quantized.assign(q_folded_kernel,
                                                read_value=False)
            self.folded_bias_quantized.assign(q_folded_bias, read_value=False)

            applied_kernel = q_folded_kernel
            applied_bias = q_folded_bias
        else:
            # Outside training, reuse the folded weights stored earlier.
            applied_kernel = self.folded_kernel_quantized
            applied_bias = self.folded_bias_quantized
        # calculate conv2d output using the quantized folded kernel
        folded_outputs = tf.keras.backend.conv2d(
            inputs,
            applied_kernel,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)
        if training is True and self.folding_mode == "ema_stats_folding":
            # Before bn freeze, rescale by sqrt(moving_var + eps) /
            # sqrt(batch_var + eps); after the freeze the factor is 1.
            y_corr = tf_utils.smart_cond(
                bn_training, lambda:
                (math_ops.sqrt(moving_variance + self.batchnorm.epsilon) *
                 math_ops.rsqrt(variance + self.batchnorm.epsilon)),
                lambda: tf.constant(1.0, shape=moving_variance.shape))
            folded_outputs = math_ops.mul(folded_outputs, y_corr)

        folded_outputs = tf.keras.backend.bias_add(
            folded_outputs, applied_bias, data_format=self.data_format)
        if self.activation is not None:
            return self.activation(folded_outputs)

        return folded_outputs
Example #38
0
    def call(self, inputs, training=False):
        """Quantizes `inputs` and refreshes the EMA-derived integer bits.

        Args:
          inputs: Tensor to quantize.
          training: Training flag. It is combined with `self.trainable`
            before use.

        Returns:
          The output of `self.quantizer` applied to `inputs`.
        """
        x = inputs
        training = training and self.trainable
        self.will_ema_freeze = self.will_ema_freeze and self.trainable

        # Update the step count if the optimizer step count is unknown
        should_count_step = tf.math.logical_and(self.is_estimating_step_count,
                                                training)
        step_increment = K.switch(should_count_step,
                                  tf.constant(1, tf.int64),
                                  tf.constant(0, tf.int64))
        self.step.assign_add(step_increment)

        # Choose the qnoise factor, a scalar from 0 to 1 giving the level of
        # quantization noise. While training it is 0.0 (no quantization)
        # until quantization_delay steps have passed and 1.0 afterwards;
        # outside training full quantization is always used.
        if training:
            delay_elapsed = tf.greater_equal(self.step,
                                             self.quantization_delay)
            qnoise_factor = K.switch(delay_elapsed,
                                     lambda: tf.constant(1.0),
                                     lambda: tf.constant(0.0))
        else:
            qnoise_factor = tf.constant(1.0)
        self.quantizer.update_qnoise_factor(qnoise_factor)
        qx = self.quantizer(x)

        # Calculate the axis along where to find the min and max EMAs
        rank = len(x.shape)
        if rank <= 1:
            axis = [0]
        elif not self.per_channel:
            axis = list(range(rank))
        elif K.image_data_format() == "channels_last":
            axis = list(range(rank - 1))
        else:
            axis = list(range(1, rank))

        # Determine if freezing the EMA
        is_ema_training = tf.constant(training, dtype=tf.bool)
        if self.will_ema_freeze:
            past_freeze_delay = tf.greater(self.step, self.ema_freeze_delay)
            is_ema_training = tf.cond(past_freeze_delay,
                                      lambda: tf.constant(False),
                                      lambda: tf.constant(True))

        def _refresh_ema_bounds():
            """Updates the min/max moving averages from the unquantized input."""
            # Remember the current qnoise factor, then switch quantization
            # noise off so the quantizer yields the pre-quantization value.
            saved_qnoise_factor = tf.identity(self.quantizer.qnoise_factor)
            self.quantizer.update_qnoise_factor(tf.constant(0.0))

            # With a qnoise_factor of 0 this is the input after the activation
            # function but before the quantizer.
            act_x = self.quantizer(x)

            batch_min = tf.squeeze(K.min(act_x, axis=axis, keepdims=True))
            K.moving_average_update(self.ema_min, batch_min, self.ema_decay)
            batch_max = tf.squeeze(K.max(act_x, axis=axis, keepdims=True))
            K.moving_average_update(self.ema_max, batch_max, self.ema_decay)

            # Reset the qnoise factor to the previous value
            self.quantizer.update_qnoise_factor(saved_qnoise_factor)

        def _leave_ema_untouched():
            """No-op branch used when the EMA is not being trained."""
            return None

        # Update the moving average when is_ema_training is True
        tf_utils.smart_cond(is_ema_training,
                            true_fn=_refresh_ema_bounds,
                            false_fn=_leave_ema_untouched)

        # Set the integer bits for the quantizer
        integer_bits = _get_integer_bits(
            min_value=self.ema_min,
            max_value=self.ema_max,
            bits=self.total_bits,
            symmetric=self.symmetric,
            keep_negative=self.keep_negative,
            is_clipping=self.po2_rounding)
        self.quantizer.integer.assign(integer_bits)

        return qx