Example #1
 def testCond(self):
   with ops.Graph().as_default():
     pred = array_ops.placeholder_with_default(True, shape=())
     x = control_flow_ops.cond(pred,
                               lambda: constant_op.constant(1),
                               lambda: constant_op.constant(2))
     self.assertIsNone(smart_cond.smart_constant_value(x))
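
The assertion above relies on the fact that a `placeholder_with_default` value is only known at run time, so the `cond` output cannot be folded to a constant. A minimal sketch of that contrast (not part of the original test file, assuming the same TensorFlow-internal modules are importable):

from tensorflow.python.framework import constant_op, ops, smart_cond
from tensorflow.python.ops import array_ops

with ops.Graph().as_default():
  # A plain constant can be resolved at graph-construction time.
  assert smart_cond.smart_constant_value(constant_op.constant(True))
  # A placeholder-backed value is only known when the graph runs,
  # so smart_constant_value reports "unknown" by returning None.
  pred = array_ops.placeholder_with_default(True, shape=())
  assert smart_cond.smart_constant_value(pred) is None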
Example #2
def constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Arguments:
    pred: A scalar, either a Python bool or a TensorFlow boolean variable
      or tensor, or the Python integer 1 or 0.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python
      integer 1 or 0.
  """
  # Allow integer booleans.
  if isinstance(pred, int):
    if pred == 1:
      pred = True
    elif pred == 0:
      pred = False

  if isinstance(pred, variables.Variable):
    return None
  return smart_module.smart_constant_value(pred)
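
A brief usage sketch (not part of the original module), assuming the `constant_value` defined above and the `variables`/`smart_module` imports it relies on are in scope; `tf` is only used to build the inputs:

import tensorflow as tf

print(constant_value(True))               # True  (plain Python bool)
print(constant_value(0))                  # False (the integers 0 and 1 are accepted)
print(constant_value(tf.Variable(True)))  # None  (a Variable never has a static value)
print(constant_value(tf.constant(True)))  # True  (a constant tensor can be folded)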
Example #3
 def testCond(self):
     with ops.Graph().as_default():
         pred = array_ops.placeholder_with_default(True, shape=())
         x = control_flow_ops.cond(pred, lambda: constant_op.constant(1),
                                   lambda: constant_op.constant(2))
         self.assertIsNone(smart_cond.smart_constant_value(x))
Example #4
    def call(self, inputs, training=None):
        if self.scale and self.gamma_quantizer:
            quantized_gamma = self.gamma_quantizer_internal(self.gamma)
        else:
            quantized_gamma = self.gamma

        if self.center and self.beta_quantizer:
            quantized_beta = self.beta_quantizer_internal(self.beta)
        else:
            quantized_beta = self.beta

        if self.mean_quantizer:
            quantized_moving_mean = self.mean_quantizer_internal(
                self.moving_mean)
        else:
            quantized_moving_mean = self.moving_mean

        if self.variance_quantizer:
            quantized_moving_variance = self.variance_quantizer_internal(
                self.moving_variance)
        else:
            quantized_moving_variance = self.moving_variance

        training = self._get_training_value(training)

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.shape
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.shape) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(quantized_gamma), _broadcast(quantized_beta)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.smart_constant_value(training)
        if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
            quantized_mean, quantized_variance = (quantized_moving_mean,
                                                  quantized_moving_variance)
        else:
            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = len(self.axis) > 1
            mean, variance = self._moments(math_ops.cast(
                inputs, self._param_dtype),
                                           reduction_axes,
                                           keep_dims=keep_dims)

            moving_mean = self.moving_mean
            moving_variance = self.moving_variance

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: ops.convert_to_tensor(moving_mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: ops.convert_to_tensor(moving_variance))

            new_mean, new_variance = mean, variance

            if self.mean_quantizer:
                quantized_mean = self.mean_quantizer_internal(mean)
            else:
                quantized_mean = mean

            if self.variance_quantizer:
                quantized_variance = self.variance_quantizer_internal(variance)
            else:
                quantized_variance = variance

            if self._support_zero_size_input():
                inputs_size = array_ops.size(inputs)
            else:
                inputs_size = None

            def _do_update(var, value):
                """Compute the updates for mean and variance."""
                return self._assign_moving_average(var, value, self.momentum,
                                                   inputs_size)

            def mean_update():
                true_branch = lambda: _do_update(self.moving_mean, new_mean)
                false_branch = lambda: self.moving_mean
                return tf_utils.smart_cond(training, true_branch, false_branch)

            def variance_update():
                """Update the moving variance."""
                true_branch = lambda: _do_update(self.moving_variance,
                                                 new_variance)
                false_branch = lambda: self.moving_variance
                return tf_utils.smart_cond(training, true_branch, false_branch)

            self.add_update(mean_update)
            self.add_update(variance_update)

        quantized_mean = math_ops.cast(quantized_mean, inputs.dtype)
        quantized_variance = math_ops.cast(quantized_variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        if scale is not None:
            scale = math_ops.cast(scale, inputs.dtype)
        # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
        # math in float16 hurts validation accuracy of popular models like resnet.
        outputs = nn.batch_normalization(inputs, _broadcast(quantized_mean),
                                         _broadcast(quantized_variance),
                                         offset, scale, self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        return outputs
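
The branching in `call()` above hinges on the smart-cond utilities: when `training` is statically known, only the selected branch is taken at trace time, and only a genuinely dynamic `training` tensor produces a graph-level cond. A minimal sketch of that dispatch (not part of the layer; it uses the TensorFlow-internal `smart_cond` module directly rather than the Keras `tf_utils` wrappers):

from tensorflow.python.framework import constant_op, ops, smart_cond
from tensorflow.python.ops import array_ops

with ops.Graph().as_default():
  # Statically known predicate: the false branch is called directly and
  # no cond is added to the graph.
  y = smart_cond.smart_cond(False,
                            lambda: constant_op.constant(1),
                            lambda: constant_op.constant(2))
  print(smart_cond.smart_constant_value(y))   # 2: the result is itself a constant

  # Dynamic predicate: both branches are wired into a cond, and the
  # result is only known at run time.
  training = array_ops.placeholder_with_default(True, shape=())
  z = smart_cond.smart_cond(training,
                            lambda: constant_op.constant(1),
                            lambda: constant_op.constant(2))
  print(smart_cond.smart_constant_value(z))   # None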
Example #5
 def _padded_shape_to_batch_shape(s):
   return tensor_shape.TensorShape([
       tensor_util.constant_value(self._batch_size)
       if smart_cond.smart_constant_value(self._drop_remainder) else None
   ]).concatenate(tensor_util.constant_value_as_shape(s))
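
Read together: the leading batch dimension is only statically known when `drop_remainder` can be resolved to True at graph-construction time; otherwise it is left as `None`. A standalone sketch of the same logic with the `self.` attributes turned into parameters (hypothetical function name, not part of tf.data):

from tensorflow.python.framework import smart_cond, tensor_shape, tensor_util

def padded_shape_to_batch_shape(padded_shape, batch_size, drop_remainder):
  # First dimension: the constant batch size if the last partial batch is
  # dropped (so every batch has the same size), otherwise unknown (None).
  return tensor_shape.TensorShape([
      tensor_util.constant_value(batch_size)
      if smart_cond.smart_constant_value(drop_remainder) else None
  ]).concatenate(tensor_util.constant_value_as_shape(padded_shape))

For example, with `padded_shape=tf.constant([224, 224, 3])`, `batch_size=tf.constant(32)`, and a statically True `drop_remainder` this yields `TensorShape([32, 224, 224, 3])`; with `drop_remainder=tf.constant(False)` it yields `TensorShape([None, 224, 224, 3])`.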