def testCond(self):
  """A cond on a non-constant (placeholder) predicate is not statically foldable.

  `smart_cond.smart_constant_value` applied to the cond's output is expected
  to return None.
  """
  graph = ops.Graph()
  with graph.as_default():
    predicate = array_ops.placeholder_with_default(True, shape=())

    def then_branch():
      return constant_op.constant(1)

    def else_branch():
      return constant_op.constant(2)

    result = control_flow_ops.cond(predicate, then_branch, else_branch)
    self.assertIsNone(smart_cond.smart_constant_value(result))
def constant_value(pred):
  """Resolve `pred` to a static Python bool when that is possible.

  Args:
    pred: A scalar predicate: a Python bool, a TensorFlow boolean Variable or
      Tensor, or the Python integer 1 or 0.

  Returns:
    True or False when `pred` has a statically known boolean value; None when
    it does not (a Variable always yields None).

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python integer
      1 or 0.
  """
  # Treat the integers 1 and 0 as spellings of True and False.
  if isinstance(pred, int) and pred in (0, 1):
    pred = bool(pred)
  # A Variable's value is never reported as a constant here.
  if isinstance(pred, variables.Variable):
    return None
  return smart_module.smart_constant_value(pred)
def constant_value(pred):
  """Return `pred` as a static bool, or None if it is only known at run time.

  Args:
    pred: A scalar: a Python bool, a TensorFlow boolean Variable or Tensor,
      or the Python integer 1 or 0.

  Returns:
    True or False if `pred` has a constant boolean value, otherwise None.
    Variables always produce None.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python integer
      1 or 0.
  """
  # Guard clause first: a Variable is never an `int`, so checking it before
  # the integer normalization below does not change the outcome.
  if isinstance(pred, variables.Variable):
    return None
  # Accept integer booleans.
  if isinstance(pred, int):
    if pred == 1:
      pred = True
    elif pred == 0:
      pred = False
  return smart_module.smart_constant_value(pred)
def testCond(self):
  """smart_constant_value cannot fold a cond whose predicate is a placeholder."""
  with ops.Graph().as_default():
    flag = array_ops.placeholder_with_default(True, shape=())
    true_fn = lambda: constant_op.constant(1)
    false_fn = lambda: constant_op.constant(2)
    cond_output = control_flow_ops.cond(flag, true_fn, false_fn)
    folded = smart_cond.smart_constant_value(cond_output)
    self.assertIsNone(folded)
def call(self, inputs, training=None):
  """Apply batch normalization with optionally quantized parameters/statistics.

  gamma, beta, the moving mean and the moving variance are each passed
  through their `*_quantizer_internal` callable when the matching quantizer
  is set (gamma and beta additionally require `self.scale` / `self.center`).
  When `training` resolves to a static False, the quantized moving statistics
  are used directly. Otherwise batch moments are computed, selected against
  the moving statistics with `tf_utils.smart_cond`, quantized, and callables
  that update the moving averages are registered through `self.add_update`.

  Args:
    inputs: Input tensor. Its rank is read via `len(inputs.shape)`, so it
      presumably must be statically known -- TODO confirm.
    training: Training-mode flag (None, a Python bool or a tensor), resolved
      through `self._get_training_value`.

  Returns:
    The result of `nn.batch_normalization`, with its static shape set to the
    shape of `inputs`.
  """
  # Quantize the scale (gamma) and offset (beta) when enabled.
  if self.scale and self.gamma_quantizer:
    quantized_gamma = self.gamma_quantizer_internal(self.gamma)
  else:
    quantized_gamma = self.gamma
  if self.center and self.beta_quantizer:
    quantized_beta = self.beta_quantizer_internal(self.beta)
  else:
    quantized_beta = self.beta
  # Quantized moving statistics. They are built unconditionally but are only
  # consumed below when the static training value is False.
  if self.mean_quantizer:
    quantized_moving_mean = self.mean_quantizer_internal(
        self.moving_mean)
  else:
    quantized_moving_mean = self.moving_mean
  if self.variance_quantizer:
    quantized_moving_variance = self.variance_quantizer_internal(
        self.moving_variance)
  else:
    quantized_moving_variance = self.moving_variance
  training = self._get_training_value(training)

  # Compute the axes along which to reduce the mean / variance
  input_shape = inputs.shape
  ndims = len(input_shape)
  reduction_axes = [i for i in range(ndims) if i not in self.axis]

  # Broadcasting only necessary for single-axis batch norm where the axis is
  # not the last dimension
  broadcast_shape = [1] * ndims
  # NOTE(review): `.value` assumes `TensorShape.dims` yields Dimension
  # objects; confirm against the TensorFlow version in use.
  broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value

  def _broadcast(v):
    # Reshape `v` to `broadcast_shape` only when its rank differs from the
    # input's and the reduction axes are not exactly the leading ndims - 1
    # axes (i.e. the normalized axis is not the last one); else pass through.
    if (v is not None and len(v.shape) != ndims and
        reduction_axes != list(range(ndims - 1))):
      return array_ops.reshape(v, broadcast_shape)
    return v

  scale, offset = _broadcast(quantized_gamma), _broadcast(quantized_beta)

  # Determine a boolean value for `training`: could be True, False, or None.
  # NOTE(review): confirm `tf_utils` exposes `smart_constant_value`; a
  # Keras-style `constant_value` wrapper additionally maps Variables to None.
  training_value = tf_utils.smart_constant_value(training)
  # An explicit `== False` (rather than `not training_value`) keeps None --
  # "value not statically known" -- on the else branch below.
  if training_value == False:  # pylint: disable=singleton-comparison,g-explicit-bool-comparison
    # `training` is statically False: normalize with the moving statistics.
    quantized_mean, quantized_variance = (quantized_moving_mean,
                                          quantized_moving_variance)
  else:
    # Some of the computations here are not necessary when training==False
    # but not a constant. However, this makes the code simpler.
    keep_dims = len(self.axis) > 1
    mean, variance = self._moments(
        math_ops.cast(inputs, self._param_dtype),
        reduction_axes,
        keep_dims=keep_dims)

    moving_mean = self.moving_mean
    moving_variance = self.moving_variance

    # Select the batch statistics when `training` is true, otherwise the
    # moving statistics.
    mean = tf_utils.smart_cond(
        training, lambda: mean, lambda: ops.convert_to_tensor(moving_mean))
    variance = tf_utils.smart_cond(
        training,
        lambda: variance,
        lambda: ops.convert_to_tensor(moving_variance))

    # Keep the un-quantized statistics; the moving-average updates below
    # consume these rather than the quantized values.
    new_mean, new_variance = mean, variance

    if self.mean_quantizer:
      quantized_mean = self.mean_quantizer_internal(mean)
    else:
      quantized_mean = mean

    if self.variance_quantizer:
      quantized_variance = self.variance_quantizer_internal(variance)
    else:
      quantized_variance = variance
    if self._support_zero_size_input():
      inputs_size = array_ops.size(inputs)
    else:
      inputs_size = None

    def _do_update(var, value):
      """Compute the updates for mean and variance."""
      return self._assign_moving_average(var, value, self.momentum,
                                         inputs_size)

    def mean_update():
      """Update the moving mean when `training` is true, else return it as-is."""
      true_branch = lambda: _do_update(self.moving_mean, new_mean)
      false_branch = lambda: self.moving_mean
      return tf_utils.smart_cond(training, true_branch, false_branch)

    def variance_update():
      """Update the moving variance."""
      true_branch = lambda: _do_update(self.moving_variance, new_variance)
      false_branch = lambda: self.moving_variance
      return tf_utils.smart_cond(training, true_branch, false_branch)

    # The callables themselves are registered (not their results); presumably
    # the base layer invokes them -- confirm against `add_update`.
    self.add_update(mean_update)
    self.add_update(variance_update)

  # Cast statistics and (if present) scale/offset to the input dtype.
  quantized_mean = math_ops.cast(quantized_mean, inputs.dtype)
  quantized_variance = math_ops.cast(quantized_variance, inputs.dtype)
  if offset is not None:
    offset = math_ops.cast(offset, inputs.dtype)
  if scale is not None:
    scale = math_ops.cast(scale, inputs.dtype)
  # TODO(reedwm): Maybe do math in float32 if given float16 inputs, if doing
  # math in float16 hurts validation accuracy of popular models like resnet.
  outputs = nn.batch_normalization(inputs,
                                   _broadcast(quantized_mean),
                                   _broadcast(quantized_variance),
                                   offset,
                                   scale,
                                   self.epsilon)
  # If some components of the shape got lost due to adjustments, fix that.
  outputs.set_shape(input_shape)
  return outputs
def _padded_shape_to_batch_shape(s):
  """Prepend a batch dimension to the padded component shape `s`.

  The leading dimension is the static value of `self._batch_size` only when
  `self._drop_remainder` has a constant truthy value; otherwise it is left
  unknown (None) -- presumably because a final partial batch may be smaller.
  `s` is converted with `tensor_util.constant_value_as_shape` before being
  appended.
  """
  if smart_cond.smart_constant_value(self._drop_remainder):
    leading_dim = tensor_util.constant_value(self._batch_size)
  else:
    leading_dim = None
  batch_shape = tensor_shape.TensorShape([leading_dim])
  return batch_shape.concatenate(tensor_util.constant_value_as_shape(s))