Example #1
 def test_constant(self):
   for v in [True, False, 1, 0, 1.0]:
     c = constant_op.constant(v)
     value = utils.constant_value(c)
     self.assertEqual(value, v)
     with self.test_session():
       self.assertEqual(c.eval(), v)
Example #2
    def _build_update_ops_variance(self, mean, variance, is_training):
        def build_update_ops():
            update_mean_op = moving_averages.assign_moving_average(
                variable=self._moving_mean,
                value=mean,
                decay=self._decay_rate,
                name="update_moving_mean").op

            update_variance_op = moving_averages.assign_moving_average(
                variable=self._moving_variance,
                value=variance,
                decay=self._decay_rate,
                name="update_moving_variance").op

            return update_mean_op, update_variance_op

        def build_no_ops():
            return (tf.no_op(), tf.no_op())

        # Only make the ops if we know that `is_training=True`, or the
        # value of `is_training` is unknown.
        is_training_const = utils.constant_value(is_training)
        if is_training_const is None or is_training_const:
            update_mean_op, update_variance_op = utils.smart_cond(
                is_training,
                build_update_ops,
                build_no_ops,
            )

            # Every new connection creates a new op which adds its contribution
            # to the running average when it is run.
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_variance_op)
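
The gating above relies on `utils.constant_value` and `utils.smart_cond` from the surrounding library. Below is a minimal sketch of what those two helpers do, assuming TensorFlow 1.x graph mode; these are simplified stand-ins for illustration, not the library implementations.

import tensorflow as tf

def my_constant_value(pred):
    # Python values (True/False/1/0/1.0) already are their static value.
    if not isinstance(pred, (tf.Tensor, tf.Variable)):
        return pred
    # Variables and placeholders have no value at graph-construction time.
    if isinstance(pred, tf.Variable):
        return None
    # Constants (and tensors computable from constants) can be folded.
    return tf.contrib.util.constant_value(pred)

def my_smart_cond(pred, true_fn, false_fn):
    # Build only one branch when `pred` is statically known; else defer to tf.cond.
    static = my_constant_value(pred)
    if static is not None:
        return true_fn() if static else false_fn()
    return tf.cond(pred, true_fn, false_fn)
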
Example #3
 def test_constant(self):
     for v in [True, False, 1, 0, 1.0]:
         c = tf.constant(v)
         value = utils.constant_value(c)
         self.assertEqual(value, v)
         with self.test_session():
             self.assertEqual(c.eval(), v)
Example #4
    def _build_update_ops_second_moment(self, mean, second_moment,
                                        is_training):
        def build_update_ops():
            update_mean_op = moving_averages.assign_moving_average(
                variable=self._moving_mean,
                value=mean,
                decay=self._decay_rate,
                name="update_moving_mean").op

            update_second_moment_op = moving_averages.assign_moving_average(
                variable=self._moving_second_moment,
                value=second_moment,
                decay=self._decay_rate,
                name="update_moving_second_moment").op

            return update_mean_op, update_second_moment_op

        def build_no_ops():
            return (tf.no_op(), tf.no_op())

        is_training_const = utils.constant_value(is_training)
        if is_training_const is None or is_training_const:
            update_mean_op, update_second_moment_op = utils.smart_cond(
                is_training,
                build_update_ops,
                build_no_ops,
            )

            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_second_moment_op)
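
For completeness, here is a sketch of how ops collected in `tf.GraphKeys.UPDATE_OPS` are typically consumed at training time, assuming TF 1.x graph mode; the model and names are illustrative.

import tensorflow as tf

# Hypothetical model producing a scalar loss with one trainable variable.
w = tf.Variable(tf.zeros([10, 1]))
x = tf.random_normal([32, 10])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - 1.0))

# Run the collected moving-average updates together with each training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
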
Example #5
 def test_placeholder(self):
     for v in [True, False, 1, 0, 1.0]:
         p = tf.placeholder(np.dtype(type(v)), [])
         x = tf.identity(p)
         value = utils.constant_value(p)
         self.assertEqual(value, None)
         with self.test_session():
             self.assertEqual(x.eval(feed_dict={p: v}), v)
Example #6
 def test_placeholder(self):
   for v in [True, False, 1, 0, 1.0]:
     p = array_ops.placeholder(np.dtype(type(v)), [])
     x = array_ops.identity(p)
     value = utils.constant_value(p)
     self.assertEqual(value, None)
     with self.test_session():
       self.assertEqual(x.eval(feed_dict={p: v}), v)
Example #7
 def test_variable(self):
     for v in [True, False, 1, 0, 1.0]:
         with tf.Graph().as_default() as g, self.test_session(g) as sess:
             x = tf.Variable(v)
             value = utils.constant_value(x)
             self.assertEqual(value, None)
             sess.run(tf.global_variables_initializer())
             self.assertEqual(x.eval(), v)
Example #8
 def test_variable(self):
   for v in [True, False, 1, 0, 1.0]:
     with ops.Graph().as_default() as g, self.test_session(g) as sess:
       x = variables.Variable(v)
       value = utils.constant_value(x)
       self.assertEqual(value, None)
       sess.run(variables.global_variables_initializer())
       self.assertEqual(x.eval(), v)
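
Taken together, the tests above pin down the contract these layers rely on: a Python bool or `tf.constant` lets the layer pick one branch while the graph is built, whereas a placeholder or `tf.Variable` forces a runtime `tf.cond`. A small sketch of that distinction, assuming TF 1.x; identifiers are illustrative.

import tensorflow as tf

def build_branch(is_training):
    # Static value known: only one branch is ever built.
    if isinstance(is_training, bool):
        return tf.constant(1.0) if is_training else tf.constant(0.0)
    # Otherwise both branches are built and selected when the value is fed.
    return tf.cond(is_training, lambda: tf.constant(1.0), lambda: tf.constant(0.0))

static_out = build_branch(True)                        # no tf.cond in the graph
dynamic_flag = tf.placeholder(tf.bool, [], name="is_training")
dynamic_out = build_branch(dynamic_flag)               # tf.cond chosen at run time
with tf.Session() as sess:
    print(sess.run(dynamic_out, feed_dict={dynamic_flag: False}))  # 0.0
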
Example #9
    def _build_update_ops(self, mean, variance, is_training):
        """Builds the moving average update ops when using moving variance.

    Args:
      mean: The mean value to update with.
      variance: The variance value to update with.
      is_training: Boolean Tensor to indicate if we're currently in
        training mode.

    Returns:
      Tuple of `(update_mean_op, update_variance_op)` when `is_training` is or
      could be `True`. Returns `None` when `is_training=False`.
    """
        def build_update_ops():
            """Builds the exponential moving average update ops."""

            update_mean_op = moving_averages.assign_moving_average(
                variable=self._moving_mean,
                value=mean,
                decay=self._decay_rate,
                zero_debias=False,
                name="update_moving_mean").op

            update_variance_op = moving_averages.assign_moving_average(
                variable=self._moving_variance,
                value=variance,
                decay=self._decay_rate,
                zero_debias=False,
                name="update_moving_variance").op

            return update_mean_op, update_variance_op

        def build_no_ops():
            return (tf.no_op(), tf.no_op())

        # Only make the ops if we know that `is_training=True`, or the value of
        # `is_training` is unknown.
        is_training_const = utils.constant_value(is_training)
        if is_training_const is None or is_training_const:
            update_mean_op, update_variance_op = utils.smart_cond(
                is_training,
                build_update_ops,
                build_no_ops,
            )
            return (update_mean_op, update_variance_op)
        else:
            return None
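
A hedged sketch of how a caller might consume the `(update_mean_op, update_variance_op)` / `None` return value of the method above; the caller itself is hypothetical, and mirroring the other examples it registers the ops in `tf.GraphKeys.UPDATE_OPS`.

import tensorflow as tf

def register_update_ops(update_ops):
    # `update_ops` is the tuple returned by _build_update_ops, or None when
    # the layer was built with a statically-False is_training.
    if update_ops is not None:
        update_mean_op, update_variance_op = update_ops
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_variance_op)
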
Example #10
    def _build_update_ops_second_moment(self, mean, second_moment,
                                        is_training):
        """Builds the moving average update ops when using the moving second moment.

    Args:
      mean: The mean value to update with.
      second_moment: The second_moment value to update with.
      is_training: Boolean Tensor to indicate if we're currently in
        training mode.
    """
        def build_update_ops():
            """Builds the exponential moving average update ops."""

            update_mean_op = moving_averages.assign_moving_average(
                variable=self._moving_mean,
                value=mean,
                decay=self._decay_rate,
                name="update_moving_mean").op

            update_second_moment_op = moving_averages.assign_moving_average(
                variable=self._moving_second_moment,
                value=second_moment,
                decay=self._decay_rate,
                name="update_moving_second_moment").op

            return update_mean_op, update_second_moment_op

        def build_no_ops():
            return (tf.no_op(), tf.no_op())

        # Only make the ops if we know that `is_training=True`, or the value of
        # `is_training` is unknown.
        is_training_const = utils.constant_value(is_training)
        if is_training_const is None or is_training_const:
            update_mean_op, update_second_moment_op = utils.smart_cond(
                is_training,
                build_update_ops,
                build_no_ops,
            )

            # Every new connection creates a new op which adds its contribution
            # to the running average when it is run.
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean_op)
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS,
                                 update_second_moment_op)
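
This variant tracks a moving second moment rather than a moving variance; at inference time the variance would then typically be recovered from the identity Var[x] = E[x^2] - (E[x])^2. A sketch of that recovery, assuming TF 1.x; the moving variables here are illustrative.

import tensorflow as tf

moving_mean = tf.Variable(tf.zeros([64]), trainable=False)
moving_second_moment = tf.Variable(tf.ones([64]), trainable=False)

# Var[x] = E[x^2] - (E[x])^2, evaluated with the tracked moving statistics.
moving_variance = moving_second_moment - tf.square(moving_mean)
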
Example #11
def dropout(inputs,
            keep_prob=0.5,
            noise_shape=None,
            is_training=True,
            outputs_collections=None,
            scope=None):
    """Returns a dropout op applied to the input.

  With probability `keep_prob`, outputs the input element scaled up by
  `1 / keep_prob`, otherwise outputs `0`.  The scaling is so that the expected
  sum is unchanged.

  Args:
    inputs: the tensor to pass to the nn.dropout op.
    keep_prob: A scalar `Tensor` with the same type as x. The probability
      that each element is kept.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the
      shape for randomly generated keep/drop flags.
    is_training: A bool `Tensor` indicating whether or not the model
      is in training mode. If so, dropout is applied and values scaled.
      Otherwise, inputs is returned.
    outputs_collections: collection to add the outputs.
    scope: Optional scope for op_scope.

  Returns:
    a tensor representing the output of the operation.
  """
    with ops.op_scope([inputs], scope, 'Dropout') as sc:
        inputs = ops.convert_to_tensor(inputs)
        is_training_value = utils.constant_value(is_training, dtypes.bool)
        if is_training_value is not None:
            if is_training_value:
                outputs = nn.dropout(inputs, keep_prob, noise_shape)
            else:
                outputs = inputs
        else:

            def _dropout():
                return nn.dropout(inputs, keep_prob, noise_shape)

            outputs = control_flow_ops.cond(is_training, _dropout,
                                            lambda: inputs)
        return utils.collect_named_outputs(outputs_collections, sc, outputs)
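
A usage sketch of the layer above, assuming TF 1.x and this contrib-style `dropout` being in scope; it contrasts the statically known and runtime-chosen `is_training` paths described in the docstring (tensor shapes are illustrative).

import tensorflow as tf

x = tf.random_normal([8, 128])

# Static Python bool: the branch is chosen while building the graph,
# so no tf.cond appears for these two calls.
y_train = dropout(x, keep_prob=0.5, is_training=True)    # nn.dropout applied
y_eval = dropout(x, keep_prob=0.5, is_training=False)    # identity pass-through

# Bool tensor: a tf.cond picks dropout vs. identity when the value is fed.
is_training = tf.placeholder(tf.bool, [], name="is_training")
y_dynamic = dropout(x, keep_prob=0.5, is_training=is_training)
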
Example #12
def dropout(inputs,
            keep_prob=0.5,
            noise_shape=None,
            is_training=True,
            outputs_collections=None,
            scope=None):
  """Returns a dropout op applied to the input.

  With probability `keep_prob`, outputs the input element scaled up by
  `1 / keep_prob`, otherwise outputs `0`.  The scaling is so that the expected
  sum is unchanged.

  Args:
    inputs: the tensor to pass to the nn.dropout op.
    keep_prob: A scalar `Tensor` with the same type as x. The probability
      that each element is kept.
    noise_shape: A 1-D `Tensor` of type `int32`, representing the
      shape for randomly generated keep/drop flags.
    is_training: A bool `Tensor` indicating whether or not the model
      is in training mode. If so, dropout is applied and values scaled.
      Otherwise, inputs is returned.
    outputs_collections: collection to add the outputs.
    scope: Optional scope for op_scope.

  Returns:
    a tensor representing the output of the operation.
  """
  with ops.op_scope([inputs], scope, 'Dropout') as sc:
    inputs = ops.convert_to_tensor(inputs)
    is_training_value = utils.constant_value(is_training, dtypes.bool)
    if is_training_value is not None:
      if is_training_value:
        outputs = nn.dropout(inputs, keep_prob, noise_shape)
      else:
        outputs = inputs
    else:
      def _dropout():
        return nn.dropout(inputs, keep_prob, noise_shape)
      outputs = control_flow_ops.cond(is_training,
                                      _dropout,
                                      lambda: inputs)
    return utils.collect_named_outputs(outputs_collections, sc, outputs)
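
The docstring's claim that scaling by `1 / keep_prob` keeps the expected sum unchanged can be checked numerically; here is a small NumPy sketch of the inverted-dropout arithmetic (illustrative only, not the TensorFlow kernel).

import numpy as np

keep_prob = 0.5
x = np.ones(100000)

# Keep each element with probability keep_prob and scale survivors by 1/keep_prob.
mask = np.random.rand(x.size) < keep_prob
out = np.where(mask, x / keep_prob, 0.0)

# E[out] = keep_prob * (x / keep_prob) + (1 - keep_prob) * 0 = x
print(out.mean())  # close to 1.0
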
Example #13
 def test_value(self):
     for v in [True, False, 1, 0, 1.0]:
         value = utils.constant_value(v)
         self.assertEqual(value, v)
Example #14
def nan_batch_norm(inputs, decay=0.999, center=True, scale=False, epsilon=0.001,
        is_training=True, reuse=None, variables_collections=None, outputs_collections=None,
        trainable=False, scope=None):
    with variable_scope.variable_op_scope([inputs],
                    scope, 'NanBatchNorm', reuse=reuse) as sc:
        inputs_shape = inputs.get_shape()
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
          raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype
        axis = list(range(inputs_rank - 1))
        params_shape = inputs_shape[-1:]
        beta, gamma = None, None
        if center:
          beta_collections = utils.get_variable_collections(variables_collections,
                                                            'beta')
          beta = variables.model_variable('beta',
                                          shape=params_shape,
                                          dtype=dtype,
                                          initializer=init_ops.zeros_initializer,
                                          collections=beta_collections,
                                          trainable=False)
        if scale:
          gamma_collections = utils.get_variable_collections(variables_collections,
                                                             'gamma')
          gamma = variables.model_variable('gamma',
                                           shape=params_shape,
                                           dtype=dtype,
                                           initializer=init_ops.ones_initializer,
                                           collections=gamma_collections,
                                           trainable=trainable)
        # Create moving_mean and moving_variance variables and add them to the
        # appropriate collections.
        moving_mean_collections = utils.get_variable_collections(
            variables_collections, 'moving_mean')
        moving_mean = variables.model_variable(
            'moving_mean',
            shape=params_shape,
            dtype=dtype,
            initializer=init_ops.zeros_initializer,
            trainable=False,
            collections=moving_mean_collections)
        moving_variance_collections = utils.get_variable_collections(
            variables_collections, 'moving_variance')
        moving_variance = variables.model_variable(
            'moving_variance',
            shape=params_shape,
            dtype=dtype,
            initializer=init_ops.ones_initializer,
            trainable=False,
            collections=moving_variance_collections)
        is_training_value = utils.constant_value(is_training)
        need_moments = is_training_value is None or is_training_value
        if need_moments:
            mean = nanmean(inputs, axis=axis)
            variance = nanvar(inputs, axis=axis)
            moving_mean = moving_averages.assign_moving_average(
                moving_mean, mean, decay)
            moving_variance = moving_averages.assign_moving_average(
                moving_variance, variance, decay)
        mean, variance = moving_mean, moving_variance
        outputs = tf.nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon)
        outputs.set_shape(inputs_shape)
        return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
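
`moving_averages.assign_moving_average(var, value, decay)` moves the variable towards `value` by an exponential moving average. A NumPy sketch of the arithmetic performed each step (ignoring zero-debiasing); the statistics are illustrative.

import numpy as np

decay = 0.999
moving_mean = np.zeros(4)

def ema_step(moving, value, decay):
    # Equivalent to: moving -= (1 - decay) * (moving - value)
    return decay * moving + (1.0 - decay) * value

for _ in range(1000):
    batch_mean = np.random.randn(4) + 5.0   # pretend per-batch statistic
    moving_mean = ema_step(moving_mean, batch_mean, decay)

print(moving_mean)  # slowly converging towards ~5.0
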
Example #15
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               initializers={},
               updates_collections=None,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=False,
               scope=None):
    """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

    "Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift"

    Sergey Ioffe, Christian Szegedy

  Can be used as a normalizer function for conv2d and fully_connected.

  Note: When is_training is True, the moving_mean and moving_variance need to be
  updated. By default the update_ops are placed in tf.GraphKeys.UPDATE_OPS, so
  they need to be added as a dependency to the train_op. For example:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      updates = tf.group(*update_ops)
      total_loss = control_flow_ops.with_dependencies([updates], total_loss)

  One can set updates_collections=None to force the updates in place, but that
  can have a speed penalty, especially in distributed settings.

  Args:
    inputs: a tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension.
    decay: decay for the moving average.
    center: If True, subtract `beta`. If False, `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: activation function, default set to None to skip it and
      maintain a linear activation.
    updates_collections: collections to collect the update ops for computation.
      The updates_ops need to be executed with the train_op.
      If None, a control dependency would be added to make sure the updates are
      computed in place.
    is_training: whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_scope`.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: if rank or last dimension of `inputs` is undefined.
  """

    with variable_scope.variable_scope(scope,
                                       'BatchNorm', [inputs],
                                       reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        inputs_shape = inputs.get_shape()
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype
        axis = list(range(inputs_rank - 1))
        params_shape = inputs_shape[-1:]
        if not params_shape.is_fully_defined():
            raise ValueError('Inputs %s has undefined last dimension %s.' %
                             (inputs.name, params_shape))
        # Allocate parameters for the beta and gamma of the normalization.
        beta, gamma = None, None

        # Create moving_mean and moving_variance variables and add them to the
        # appropriate collections.
        moving_mean_initializer = initializers.get('moving_mean',
                                                   init_ops.zeros_initializer)
        moving_mean = variables.model_variable(
            'moving_mean',
            shape=params_shape,
            dtype=dtype,
            initializer=moving_mean_initializer,
            trainable=False)
        moving_variance_initializer = initializers.get(
            'moving_variance', init_ops.ones_initializer)
        moving_variance = variables.model_variable(
            'moving_variance',
            shape=params_shape,
            dtype=dtype,
            initializer=moving_variance_initializer,
            trainable=False)

        # If `is_training` doesn't have a constant value, because it is a `Tensor`,
        # a `Variable` or `Placeholder` then is_training_value will be None and
        # `needs_moments` will be true.
        is_training_value = utils.constant_value(is_training)
        need_moments = is_training_value is None or is_training_value
        if need_moments:
            # Calculate the moments based on the individual batch.
            # Use a copy of moving_mean as a shift to compute more reliable moments.
            shift = math_ops.add(moving_mean, 0)
            mean, variance = nn.moments(inputs, axis, shift=shift)
            moving_vars_fn = lambda: (moving_mean, moving_variance)
            if updates_collections is None:

                def _force_updates():
                    """Internal function forces updates moving_vars if is_training."""
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, mean, decay)
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, variance, decay)
                    with ops.control_dependencies(
                        [update_moving_mean, update_moving_variance]):
                        return array_ops.identity(mean), array_ops.identity(
                            variance)

                mean, variance = utils.smart_cond(is_training, _force_updates,
                                                  moving_vars_fn)
            else:

                def _delay_updates():
                    """Internal function that delay updates moving_vars if is_training."""
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, mean, decay)
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, variance, decay)
                    return update_moving_mean, update_moving_variance

                update_mean, update_variance = utils.smart_cond(
                    is_training, _delay_updates, moving_vars_fn)
                ops.add_to_collections(updates_collections, update_mean)
                ops.add_to_collections(updates_collections, update_variance)
                # Use computed moments during training and moving_vars otherwise.
                vars_fn = lambda: (mean, variance)
                mean, variance = utils.smart_cond(is_training, vars_fn,
                                                  moving_vars_fn)
        else:
            mean, variance = moving_mean, moving_variance
        # Compute batch_normalization.
        outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                         epsilon)
        outputs.set_shape(inputs_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections,
                                           sc.original_name_scope, outputs)
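
As the docstring notes, the layer is meant to be plugged into conv2d/fully_connected as a normalizer. A sketch of that wiring, assuming the standard tf.contrib.layers API rather than this modified variant; the placeholder and layer sizes are illustrative.

import tensorflow as tf
from tensorflow.contrib import layers

images = tf.placeholder(tf.float32, [None, 32, 32, 3])
is_training = tf.placeholder(tf.bool, [], name="is_training")

# batch_norm is applied to the conv output before the activation.
net = layers.conv2d(
    images, num_outputs=64, kernel_size=3,
    normalizer_fn=layers.batch_norm,
    normalizer_params={'is_training': is_training, 'decay': 0.999})
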
Example #16
def batch_norm(
        inputs,
        decay=0.999,
        center=True,
        scale=False,
        epsilon=0.001,
        activation_fn=None,
        param_initializers=None,
        updates_collections=ops.GraphKeys.UPDATE_OPS,
        is_training=True,
        reuse=None,
        variables_collections=None,
        outputs_collections=None,
        trainable=True,
        batch_weights=None,
        fused=False,
        #data_format=DATA_FORMAT_NHWC,
        scope=None):
    """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
    "Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift"
    Sergey Ioffe, Christian Szegedy
  Can be used as a normalizer function for conv2d and fully_connected.
  Note: When is_training is True, the moving_mean and moving_variance need to be
  updated. By default the update_ops are placed in `tf.GraphKeys.UPDATE_OPS`, so
  they need to be added as a dependency to the `train_op`. For example:
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      updates = tf.group(*update_ops)
      total_loss = control_flow_ops.with_dependencies([updates], total_loss)
  One can set updates_collections=None to force the updates in place, but that
  can have a speed penalty, especially in distributed settings.
  Args:
    inputs: a tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    decay: decay for the moving average.
    center: If True, subtract `beta`. If False, `beta` is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: small float added to variance to avoid dividing by zero.
    activation_fn: activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: optional initializers for beta, gamma, moving mean and
      moving variance.
    updates_collections: collections to collect the update ops for computation.
      The updates_ops need to be executed with the train_op.
      If None, a control dependency would be added to make sure the updates are
      computed in place.
    is_training: whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: optional collections for the variables.
    outputs_collections: collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    batch_weights: An optional tensor of shape `[batch_size]`,
      containing a frequency weight for each batch item. If present,
      then the batch normalization uses weighted mean and
      variance. (This can be used to correct for bias in training
      example selection.)
    fused:  Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    scope: Optional scope for `variable_scope`.
  Returns:
    A `Tensor` representing the output of the operation.
  Raises:
    ValueError: if `batch_weights` is not None and `fused` is True.
    ValueError: if `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: if `data_format` is `NCHW` while `fused` is False.
    ValueError: if the rank of `inputs` is undefined.
    ValueError: if rank or last dimension of `inputs` is undefined.
  """
    if fused:
        if batch_weights is not None:
            raise ValueError('Weighted mean and variance is not currently '
                             'supported for fused batch norm.')
        return _fused_batch_norm(inputs,
                                 decay=decay,
                                 center=center,
                                 scale=scale,
                                 epsilon=epsilon,
                                 activation_fn=activation_fn,
                                 param_initializers=param_initializers,
                                 updates_collections=updates_collections,
                                 is_training=is_training,
                                 reuse=reuse,
                                 variables_collections=variables_collections,
                                 outputs_collections=outputs_collections,
                                 trainable=trainable,
                                 scope=scope)

    #if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    #raise ValueError('data_format has to be either NCHW or NHWC.')
    #if data_format == DATA_FORMAT_NCHW:
    #raise ValueError('data_format must be NHWC if fused is False.')

    with variable_scope.variable_scope(scope,
                                       'BatchNorm', [inputs],
                                       reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        inputs_shape = inputs.get_shape()
        inputs_rank = inputs_shape.ndims
        if inputs_rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        dtype = inputs.dtype.base_dtype
        if batch_weights is not None:
            batch_weights = ops.convert_to_tensor(batch_weights)
            inputs_shape[0:1].assert_is_compatible_with(
                batch_weights.get_shape())
            # Reshape batch weight values so they broadcast across inputs.
            nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
            batch_weights = array_ops.reshape(batch_weights, nshape)
        axis = list(range(inputs_rank - 1))
        params_shape = inputs_shape[-1:]
        if not params_shape.is_fully_defined():
            raise ValueError('Inputs %s has undefined last dimension %s.' %
                             (inputs.name, params_shape))

        # Allocate parameters for the beta and gamma of the normalization.
        beta, gamma = None, None
        if not param_initializers:
            param_initializers = {}
        if center:
            beta_collections = utils.get_variable_collections(
                variables_collections, 'beta')
            beta_initializer = param_initializers.get(
                'beta', init_ops.zeros_initializer)
            beta = variables.model_variable('beta',
                                            shape=params_shape,
                                            dtype=dtype,
                                            initializer=beta_initializer,
                                            collections=beta_collections,
                                            trainable=trainable)
        if scale:
            gamma_collections = utils.get_variable_collections(
                variables_collections, 'gamma')
            gamma_initializer = param_initializers.get(
                'gamma', init_ops.ones_initializer)
            gamma = variables.model_variable('gamma',
                                             shape=params_shape,
                                             dtype=dtype,
                                             initializer=gamma_initializer,
                                             collections=gamma_collections,
                                             trainable=trainable)

        # Create moving_mean and moving_variance variables and add them to the
        # appropriate collections. We disable variable partitioning while creating
        # them, because assign_moving_average is not yet supported for partitioned
        # variables.
        partitioner = variable_scope.get_variable_scope().partitioner
        try:
            variable_scope.get_variable_scope().set_partitioner(None)
            moving_mean_collections = utils.get_variable_collections(
                variables_collections, 'moving_mean')
            moving_mean_initializer = param_initializers.get(
                'moving_mean', init_ops.zeros_initializer)
            moving_mean = variables.model_variable(
                'moving_mean',
                shape=params_shape,
                dtype=dtype,
                initializer=moving_mean_initializer,
                trainable=False,
                collections=moving_mean_collections)
            moving_variance_collections = utils.get_variable_collections(
                variables_collections, 'moving_variance')
            moving_variance_initializer = param_initializers.get(
                'moving_variance', init_ops.ones_initializer)
            moving_variance = variables.model_variable(
                'moving_variance',
                shape=params_shape,
                dtype=dtype,
                initializer=moving_variance_initializer,
                trainable=False,
                collections=moving_variance_collections)
        finally:
            variable_scope.get_variable_scope().set_partitioner(partitioner)

        # If `is_training` doesn't have a constant value, because it is a `Tensor`,
        # a `Variable` or `Placeholder` then is_training_value will be None and
        # `needs_moments` will be true.
        is_training_value = utils.constant_value(is_training)
        need_moments = is_training_value is None or is_training_value
        if need_moments:
            # Calculate the moments based on the individual batch.
            if batch_weights is None:
                # Use a copy of moving_mean as a shift to compute more reliable moments.
                shift = math_ops.add(moving_mean, 0)
                mean, variance = nn.moments(inputs, axis, shift=shift)
            else:
                mean, variance = nn.weighted_moments(inputs, axis,
                                                     batch_weights)

            moving_vars_fn = lambda: (moving_mean, moving_variance)
            if updates_collections is None:

                def _force_updates():
                    """Internal function forces updates moving_vars if is_training."""
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, mean, decay)
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, variance, decay)
                    with ops.control_dependencies(
                        [update_moving_mean, update_moving_variance]):
                        return array_ops.identity(mean), array_ops.identity(
                            variance)

                mean, variance = utils.smart_cond(is_training, _force_updates,
                                                  moving_vars_fn)
            else:

                def _delay_updates():
                    """Internal function that delay updates moving_vars if is_training."""
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, mean, decay, zero_debias=False)
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, variance, decay, zero_debias=False)
                    return update_moving_mean, update_moving_variance

                update_mean, update_variance = utils.smart_cond(
                    is_training, _delay_updates, moving_vars_fn)
                ops.add_to_collections(updates_collections, update_mean)
                ops.add_to_collections(updates_collections, update_variance)
                # Use computed moments during training and moving_vars otherwise.
                vars_fn = lambda: (mean, variance)
                mean, variance = utils.smart_cond(is_training, vars_fn,
                                                  moving_vars_fn)
        else:
            mean, variance = moving_mean, moving_variance
        # Compute batch_normalization.
        # Print out offset, scale, mean, variance
        import tensorflow as tf
        print_op_gamma = tf.Print(gamma, [gamma], message="scale factor is: ")
        print_op_beta = tf.Print(beta, [beta], message="offset factor is: ")
        print_op_mean = tf.Print(mean, [mean], message="mean is: ")
        print_op_var = tf.Print(variance, [variance], message="variance is: ")
        with ops.control_dependencies(
            [print_op_gamma, print_op_beta, print_op_mean, print_op_var]):
            outputs = nn.batch_normalization(inputs, mean, variance, beta,
                                             gamma, epsilon)
        outputs.set_shape(inputs_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections,
                                           sc.original_name_scope, outputs)
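
The `batch_weights` argument described above switches the layer to frequency-weighted statistics. A NumPy sketch of the weighted mean and variance the docstring refers to (illustrative shapes; the layer itself uses nn.weighted_moments).

import numpy as np

x = np.random.randn(16, 8)            # [batch_size, features]
w = np.random.rand(16)                # per-example frequency weights

wn = w / w.sum()
weighted_mean = (wn[:, None] * x).sum(axis=0)
weighted_var = (wn[:, None] * (x - weighted_mean) ** 2).sum(axis=0)
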
Example #17
    def batch_norm_backbone(inputs,
                            decay=0.999,
                            center=True,
                            scale=False,
                            epsilon=0.001,
                            activation_fn=None,
                            param_initializers=None,
                            param_regularizers=None,
                            updates_collections=ops.GraphKeys.UPDATE_OPS,
                            is_training=True,
                            reuse=None,
                            variables_collections=None,
                            outputs_collections=None,
                            trainable=True,
                            batch_weights=None,
                            fused=None,
                            data_format=DATA_FORMAT_NHWC,
                            zero_debias_moving_mean=False,
                            scope=None,
                            renorm=False,
                            renorm_clipping=None,
                            renorm_decay=0.99,
                            adjustment=None,
                            tower_config=None):

        """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
          "Batch Normalization: Accelerating Deep Network Training by Reducing
          Internal Covariate Shift"
          Sergey Ioffe, Christian Szegedy
        Can be used as a normalizer function for conv2d and fully_connected. The
        normalization is over all but the last dimension if `data_format` is `NHWC`
        and all but the second dimension if `data_format` is `NCHW`. In case of a 2D
        tensor this corresponds to the batch dimension, while in case of a 4D tensor
        this corresponds to the batch and space dimensions.
        Note: when training, the moving_mean and moving_variance need to be updated.
        By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they
        need to be added as a dependency to the `train_op`. For example:
        ```python
          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
          with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss)
        ```
        One can set updates_collections=None to force the updates in place, but that
        can have a speed penalty, especially in distributed settings.
        Args:
          inputs: A tensor with 2 or more dimensions, where the first dimension has
            `batch_size`. The normalization is over all but the last dimension if
            `data_format` is `NHWC` and the second dimension if `data_format` is
            `NCHW`.
          decay: Decay for the moving average. Reasonable values for `decay` are close
            to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
            Lower `decay` value (recommend trying `decay`=0.9) if model experiences
            reasonably good training performance but poor validation and/or test
            performance. Try zero_debias_moving_mean=True for improved stability.
          center: If True, add offset of `beta` to normalized tensor. If False, `beta`
            is ignored.
          scale: If True, multiply by `gamma`. If False, `gamma` is
            not used. When the next layer is linear (also e.g. `nn.relu`), this can be
            disabled since the scaling can be done by the next layer.
          epsilon: Small float added to variance to avoid dividing by zero.
          activation_fn: Activation function, default set to None to skip it and
            maintain a linear activation.
          param_initializers: Optional initializers for beta, gamma, moving mean and
            moving variance.
          param_regularizers: Optional regularizer for beta and gamma.
          updates_collections: Collections to collect the update ops for computation.
            The updates_ops need to be executed with the train_op.
            If None, a control dependency would be added to make sure the updates are
            computed in place.
          is_training: Whether or not the layer is in training mode. In training mode
            it would accumulate the statistics of the moments into `moving_mean` and
            `moving_variance` using an exponential moving average with the given
            `decay`. When it is not in training mode then it would use the values of
            the `moving_mean` and the `moving_variance`.
          reuse: Whether or not the layer and its variables should be reused. To be
            able to reuse the layer scope must be given.
          variables_collections: Optional collections for the variables.
          outputs_collections: Collections to add the outputs.
          trainable: If `True` also add variables to the graph collection
            `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
          batch_weights: An optional tensor of shape `[batch_size]`,
            containing a frequency weight for each batch item. If present,
            then the batch normalization uses weighted mean and
            variance. (This can be used to correct for bias in training
            example selection.)
          fused: if `None` or `True`, use a faster, fused implementation if possible.
            If `False`, use the system recommended implementation.
          data_format: A string. `NHWC` (default) and `NCHW` are supported.
          zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
            pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
          scope: Optional scope for `variable_scope`.
          renorm: Whether to use Batch Renormalization
            (https://arxiv.org/abs/1702.03275). This adds extra variables during
            training. The inference is the same for either value of this parameter.
          renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
            scalar `Tensors` used to clip the renorm correction. The correction
            `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
            `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
            dmax are set to inf, 0, inf, respectively.
          renorm_decay: Momentum used to update the moving means and standard
            deviations with renorm. Unlike `momentum`, this affects training
            and should be neither too small (which would add noise) nor too large
            (which would give stale estimates). Note that `decay` is still applied
            to get the means and variances for inference.
          adjustment: A function taking the `Tensor` containing the (dynamic) shape of
            the input tensor and returning a pair (scale, bias) to apply to the
            normalized values (before gamma and beta), only during training. For
            example,
              `adjustment = lambda shape: (
                tf.random_uniform(shape[-1:], 0.93, 1.07),
                tf.random_uniform(shape[-1:], -0.1, 0.1))`
            will scale the normalized value by up to 7% up or down, then shift the
            result by up to 0.1 (with independent scaling and bias for each feature
            but shared across all examples), and finally apply gamma and/or beta. If
            `None`, no adjustment is applied.
        Returns:
          A `Tensor` representing the output of the operation.
        Raises:
          ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
          ValueError: If the rank of `inputs` is undefined.
          ValueError: If rank or channels dimension of `inputs` is undefined.
        """
        # if fused is None:
        #     fused = True

        # Only use _fused_batch_norm if all of the following three
        # conditions are true:
        # (1) fused is set True;
        # (2) it is possible to use (currently it doesn't support batch weights,
        #   renorm, and the case when rank is neither 2 nor 4);
        # (3) it is used with zero_debias_moving_mean, or an input shape of rank 2,
        #   or non-default updates_collections (not implemented in
        #   normalization_layers.BatchNormalization yet); otherwise use the fused
        #   implementation in normalization_layers.BatchNormalization.
        # inputs = ops.convert_to_tensor(inputs)
        # rank = inputs.get_shape().ndims
        # possible_to_fuse = (
        #     batch_weights is None and not renorm and rank in [2, 4] and
        #     adjustment is None)
        # if fused and possible_to_fuse and (
        #                 zero_debias_moving_mean or rank == 2 or
        #                 updates_collections is not ops.GraphKeys.UPDATE_OPS):
        #     return _fused_batch_norm(
        #         inputs,
        #         decay=decay,
        #         center=center,
        #         scale=scale,
        #         epsilon=epsilon,
        #         activation_fn=activation_fn,
        #         param_initializers=param_initializers,
        #         param_regularizers=param_regularizers,
        #         updates_collections=updates_collections,
        #         is_training=is_training,
        #         reuse=reuse,
        #         variables_collections=variables_collections,
        #         outputs_collections=outputs_collections,
        #         trainable=trainable,
        #         data_format=data_format,
        #         zero_debias_moving_mean=zero_debias_moving_mean,
        #         scope=scope)

        if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
            raise ValueError('data_format has to be either NCHW or NHWC.')

        layer_variable_getter = _build_variable_getter()
        with variable_scope.variable_scope(
                scope,
                'BatchNorm', [inputs],
                reuse=reuse,
                custom_getter=layer_variable_getter) as sc:
            inputs = ops.convert_to_tensor(inputs)

            # # Determine whether we can use the core layer class.
            # if (batch_weights is None and
            #             updates_collections is ops.GraphKeys.UPDATE_OPS and
            #         not zero_debias_moving_mean):
            #     print("F**K !!!!")
            #     # Use the core layer class.
            #     axis = 1 if data_format == DATA_FORMAT_NCHW else -1
            #     if not param_initializers:
            #         param_initializers = {}
            #     beta_initializer = param_initializers.get('beta',
            #                                               init_ops.zeros_initializer())
            #     gamma_initializer = param_initializers.get('gamma',
            #                                                init_ops.ones_initializer())
            #     moving_mean_initializer = param_initializers.get(
            #         'moving_mean', init_ops.zeros_initializer())
            #     moving_variance_initializer = param_initializers.get(
            #         'moving_variance', init_ops.ones_initializer())
            #     if not param_regularizers:
            #         param_regularizers = {}
            #     beta_regularizer = param_regularizers.get('beta')
            #     gamma_regularizer = param_regularizers.get('gamma')
            #     layer = normalization_layers.BatchNormalization(
            #         axis=axis,
            #         momentum=decay,
            #         epsilon=epsilon,
            #         center=center,
            #         scale=scale,
            #         beta_initializer=beta_initializer,
            #         gamma_initializer=gamma_initializer,
            #         moving_mean_initializer=moving_mean_initializer,
            #         moving_variance_initializer=moving_variance_initializer,
            #         beta_regularizer=beta_regularizer,
            #         gamma_regularizer=gamma_regularizer,
            #         trainable=trainable,
            #         renorm=renorm,
            #         renorm_clipping=renorm_clipping,
            #         renorm_momentum=renorm_decay,
            #         adjustment=adjustment,
            #         name=sc.name,
            #         _scope=sc,
            #         _reuse=reuse,
            #         fused=fused)
            #     outputs = layer.apply(inputs, training=is_training)
            #
            #     # Add variables to collections.
            #     _add_variable_to_collections(layer.moving_mean, variables_collections,
            #                                  'moving_mean')
            #     _add_variable_to_collections(layer.moving_variance, variables_collections,
            #                                  'moving_variance')
            #     if layer.beta is not None:
            #         _add_variable_to_collections(layer.beta, variables_collections, 'beta')
            #     if layer.gamma is not None:
            #         _add_variable_to_collections(layer.gamma, variables_collections,
            #                                      'gamma')
            #
            #     if activation_fn is not None:
            #         outputs = activation_fn(outputs)
            #     return utils.collect_named_outputs(outputs_collections, sc.name, outputs)

            # Not supported by layer class: batch_weights argument,
            # and custom updates_collections. In that case, use the legacy BN
            # implementation.
            # Custom updates collections are not supported because the update logic
            # is different in this case, in particular w.r.t. "forced updates" and
            # update op reuse.
            if renorm:
                raise ValueError('renorm is not supported with batch_weights, '
                                 'updates_collections or zero_debias_moving_mean')
            inputs_shape = inputs.get_shape()
            inputs_rank = inputs_shape.ndims
            if inputs_rank is None:
                raise ValueError('Inputs %s has undefined rank.' % inputs.name)
            dtype = inputs.dtype.base_dtype
            if batch_weights is not None:
                batch_weights = ops.convert_to_tensor(batch_weights)
                inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
                # Reshape batch weight values so they broadcast across inputs.
                nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
                batch_weights = array_ops.reshape(batch_weights, nshape)

            if data_format == DATA_FORMAT_NCHW:
                moments_axes = [0] + list(range(2, inputs_rank))
                params_shape = inputs_shape[1:2]
                # For NCHW format, rather than relying on implicit broadcasting, we
                # explicitly reshape the params to params_shape_broadcast when computing
                # the moments and the batch normalization.
                params_shape_broadcast = list(
                    [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
            else:
                moments_axes = list(range(inputs_rank - 1))
                params_shape = inputs_shape[-1:]
                params_shape_broadcast = None
            if not params_shape.is_fully_defined():
                raise ValueError('Inputs %s has undefined channels dimension %s.' %
                                 (inputs.name, params_shape))

            # Allocate parameters for the beta and gamma of the normalization.
            beta, gamma = None, None
            if not param_initializers:
                param_initializers = {}
            if center:
                beta_collections = utils.get_variable_collections(variables_collections,
                                                                  'beta')
                beta_initializer = param_initializers.get('beta',
                                                          init_ops.zeros_initializer())
                beta = variables.model_variable(
                    'beta',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=beta_initializer,
                    collections=beta_collections,
                    trainable=trainable)
            if scale:
                gamma_collections = utils.get_variable_collections(
                    variables_collections, 'gamma')
                gamma_initializer = param_initializers.get('gamma',
                                                           init_ops.ones_initializer())
                gamma = variables.model_variable(
                    'gamma',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=gamma_initializer,
                    collections=gamma_collections,
                    trainable=trainable)

            # Create moving_mean and moving_variance variables and add them to the
            # appropriate collections. We disable variable partitioning while creating
            # them, because assign_moving_average is not yet supported for partitioned
            # variables (this needs to be handled carefully, as it may break
            # the checkpoint backward compatibility).
            with variable_scope.variable_scope(
                    variable_scope.get_variable_scope()) as local_scope:
                local_scope.set_partitioner(None)
                moving_mean_collections = utils.get_variable_collections(
                    variables_collections, 'moving_mean')
                moving_mean_initializer = param_initializers.get(
                    'moving_mean', init_ops.zeros_initializer())
                moving_mean = variables.model_variable(
                    'moving_mean',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=moving_mean_initializer,
                    trainable=False,
                    collections=moving_mean_collections)
                moving_variance_collections = utils.get_variable_collections(
                    variables_collections, 'moving_variance')
                moving_variance_initializer = param_initializers.get(
                    'moving_variance', init_ops.ones_initializer())
                moving_variance = variables.model_variable(
                    'moving_variance',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=moving_variance_initializer,
                    trainable=False,
                    collections=moving_variance_collections)

            # If `is_training` doesn't have a constant value, because it is a `Tensor`,
            # a `Variable` or `Placeholder` then is_training_value will be None and
            # `needs_moments` will be true.
            is_training_value = utils.constant_value(is_training)
            need_moments = is_training_value is None or is_training_value
            if need_moments:
                # Calculate the moments based on the individual batch.
                if batch_weights is None:
                    if data_format == DATA_FORMAT_NCHW:
                        mean, variance = moments(inputs, moments_axes, tower_config=tower_config, keep_dims=True)
                        mean = array_ops.reshape(mean, [-1])
                        variance = array_ops.reshape(variance, [-1])
                    else:
                        mean, variance = moments(inputs, moments_axes, tower_config=tower_config)
                else:
                    if data_format == DATA_FORMAT_NCHW:
                        mean, variance = weighted_moments(
                            inputs, moments_axes, batch_weights, tower_config, keep_dims=True)
                        mean = array_ops.reshape(mean, [-1])
                        variance = array_ops.reshape(variance, [-1])
                    else:
                        mean, variance = weighted_moments(inputs, moments_axes,
                                                             batch_weights, tower_config=tower_config)

                moving_vars_fn = lambda: (moving_mean, moving_variance)
                if updates_collections is None:

                    def _force_updates():
                        """Internal function forces updates moving_vars if is_training."""
                        update_moving_mean = moving_averages.assign_moving_average(
                            moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
                        update_moving_variance = moving_averages.assign_moving_average(
                            moving_variance, variance, decay, zero_debias=False)
                        with ops.control_dependencies(
                                [update_moving_mean, update_moving_variance]):
                            return array_ops.identity(mean), array_ops.identity(variance)

                    mean, variance = utils.smart_cond(is_training, _force_updates,
                                                      moving_vars_fn)
                else:

                    def _delay_updates():
                        """Internal function that delay updates moving_vars if is_training."""
                        update_moving_mean = moving_averages.assign_moving_average(
                            moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
                        update_moving_variance = moving_averages.assign_moving_average(
                            moving_variance, variance, decay, zero_debias=False)
                        return update_moving_mean, update_moving_variance

                    update_mean, update_variance = utils.smart_cond(
                        is_training, _delay_updates, moving_vars_fn)
                    ops.add_to_collections(updates_collections, update_mean)
                    ops.add_to_collections(updates_collections, update_variance)
                    # Use computed moments during training and moving_vars otherwise.
                    vars_fn = lambda: (mean, variance)
                    mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
            else:
                mean, variance = moving_mean, moving_variance
            if data_format == DATA_FORMAT_NCHW:
                mean = array_ops.reshape(mean, params_shape_broadcast)
                variance = array_ops.reshape(variance, params_shape_broadcast)
                if beta is not None:
                    beta = array_ops.reshape(beta, params_shape_broadcast)
                if gamma is not None:
                    gamma = array_ops.reshape(gamma, params_shape_broadcast)

            # Compute batch_normalization.
            outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                             epsilon)
            outputs.set_shape(inputs_shape)
            if activation_fn is not None:
                outputs = activation_fn(outputs)
            return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
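When `updates_collections` is set (the `_delay_updates` path above), the moving-average update ops are only collected, not executed, so the training loop has to run them together with the train op. A minimal sketch, assuming a `total_loss` tensor and an `optimizer` instance (both hypothetical names):

import tensorflow as tf

def build_train_op(total_loss, optimizer):
    # Collect the delayed batch-norm updates and make them a dependency of
    # the train op so the moving mean/variance are refreshed every step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(total_loss)
    return train_op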
Example #18
0
def batch_norm_mine_old(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               param_initializers=None,
               param_regularizers=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               batch_weights=None,
               fused=False,
               data_format=DATA_FORMAT_NHWC,
               zero_debias_moving_mean=False,
               scope=None,
               renorm=False,
               renorm_clipping=None,
               renorm_decay=0.99):
  """
  This earlier version of my modification to batch norm uses
current_mean and current_variance if is_training is True and
moving_mean and moving_variance otherwise. This was leading a large divergence between
the results depending upon whether the is_training set to True or not.

I think ideally it should always use moving_mean and moving_variance. batch_norm_mine
does this.

  Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.
copy of tensorflow.contrib.layers
  Args:
    inputs: A tensor with 2 or more dimensions, where the first dimension has
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC` and the second dimension if `data_format` is
      `NCHW`.
    decay: Decay for the moving average. Reasonable values for `decay` are close
      to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
      Lower `decay` value (recommend trying `decay`=0.9) if model experiences
      reasonably good training performance but poor validation and/or test
      performance. Try zero_debias_moving_mean=True for improved stability.
    center: If True, add offset of `beta` to normalized tensor. If False, `beta`
      is ignored.
    scale: If True, multiply by `gamma`. If False, `gamma` is
      not used. When the next layer is linear (also e.g. `nn.relu`), this can be
      disabled since the scaling can be done by the next layer.
    epsilon: Small float added to variance to avoid dividing by zero.
    activation_fn: Activation function, default set to None to skip it and
      maintain a linear activation.
    param_initializers: Optional initializers for beta, gamma, moving mean and
      moving variance.
    param_regularizers: Optional regularizer for beta and gamma.
    updates_collections: Collections to collect the update ops for computation.
      The updates_ops need to be executed with the train_op.
      If None, a control dependency would be added to make sure the updates are
      computed in place.
    is_training: Whether or not the layer is in training mode. In training mode
      it would accumulate the statistics of the moments into `moving_mean` and
      `moving_variance` using an exponential moving average with the given
      `decay`. When it is not in training mode then it would use the values of
      the `moving_mean` and the `moving_variance`.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional collections for the variables.
    outputs_collections: Collections to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
    batch_weights: An optional tensor of shape `[batch_size]`,
      containing a frequency weight for each batch item. If present,
      then the batch normalization uses weighted mean and
      variance. (This can be used to correct for bias in training
      example selection.)
    fused:  Use nn.fused_batch_norm if True, nn.batch_normalization otherwise.
    data_format: A string. `NHWC` (default) and `NCHW` are supported.
    zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new
      pair of variables 'moving_mean/biased' and 'moving_mean/local_step'.
    scope: Optional scope for `variable_scope`.
    renorm: Whether to use Batch Renormalization
      (https://arxiv.org/abs/1702.03275). This adds extra variables during
      training. The inference is the same for either value of this parameter.
    renorm_clipping: A dictionary that may map keys 'rmax', 'rmin', 'dmax' to
      scalar `Tensors` used to clip the renorm correction. The correction
      `(r, d)` is used as `corrected_value = normalized_value * r + d`, with
      `r` clipped to [rmin, rmax], and `d` to [-dmax, dmax]. Missing rmax, rmin,
      dmax are set to inf, 0, inf, respectively.
    renorm_decay: Momentum used to update the moving means and standard
      deviations with renorm. Unlike `momentum`, this affects training
      and should be neither too small (which would add noise) nor too large
      (which would give stale estimates). Note that `decay` is still applied
      to get the means and variances for inference.

  Returns:
    A `Tensor` representing the output of the operation.

  Raises:
    ValueError: If `batch_weights` is not None and `fused` is True.
    ValueError: If `param_regularizers` is not None and `fused` is True.
    ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
    ValueError: If the rank of `inputs` is undefined.
    ValueError: If rank or channels dimension of `inputs` is undefined.
  """
  if fused:
    if batch_weights is not None:
      raise ValueError('Weighted mean and variance is not currently '
                       'supported for fused batch norm.')
    if param_regularizers is not None:
      raise ValueError('Regularizers are not currently '
                       'supported for fused batch norm.')
    if renorm:
      raise ValueError('Renorm is not supported for fused batch norm.')
    return _fused_batch_norm(
        inputs,
        decay=decay,
        center=center,
        scale=scale,
        epsilon=epsilon,
        activation_fn=activation_fn,
        param_initializers=param_initializers,
        updates_collections=updates_collections,
        is_training=is_training,
        reuse=reuse,
        variables_collections=variables_collections,
        outputs_collections=outputs_collections,
        trainable=trainable,
        data_format=data_format,
        zero_debias_moving_mean=zero_debias_moving_mean,
        scope=scope)

  if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
    raise ValueError('data_format has to be either NCHW or NHWC.')

  layer_variable_getter = _build_variable_getter()
  with variable_scope.variable_scope(
      scope, 'BatchNorm', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)

    # Determine whether we can use the core layer class.
    if (batch_weights is None and
        updates_collections is ops.GraphKeys.UPDATE_OPS and
        not zero_debias_moving_mean):
      # Use the core layer class.
      axis = 1 if data_format == DATA_FORMAT_NCHW else -1
      if not param_initializers:
        param_initializers = {}
      beta_initializer = param_initializers.get('beta',
                                                init_ops.zeros_initializer())
      gamma_initializer = param_initializers.get('gamma',
                                                 init_ops.ones_initializer())
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      if not param_regularizers:
        param_regularizers = {}
      beta_regularizer = param_regularizers.get('beta')
      gamma_regularizer = param_regularizers.get('gamma')
      layer = normalization_layers.BatchNormalization(
          axis=axis,
          momentum=decay,
          epsilon=epsilon,
          center=center,
          scale=scale,
          beta_initializer=beta_initializer,
          gamma_initializer=gamma_initializer,
          moving_mean_initializer=moving_mean_initializer,
          moving_variance_initializer=moving_variance_initializer,
          beta_regularizer=beta_regularizer,
          gamma_regularizer=gamma_regularizer,
          trainable=trainable,
          renorm=renorm,
          renorm_clipping=renorm_clipping,
          renorm_momentum=renorm_decay,
          name=sc.name,
          _scope=sc,
          _reuse=reuse)
      outputs = layer.apply(inputs, training=is_training)

      # Add variables to collections.
      _add_variable_to_collections(
          layer.moving_mean, variables_collections, 'moving_mean')
      _add_variable_to_collections(
          layer.moving_variance, variables_collections, 'moving_variance')
      if layer.beta is not None:
        _add_variable_to_collections(layer.beta, variables_collections, 'beta')
      if layer.gamma is not None:
        _add_variable_to_collections(
            layer.gamma, variables_collections, 'gamma')

      if activation_fn is not None:
        outputs = activation_fn(outputs)
      return utils.collect_named_outputs(outputs_collections,
                                         sc.original_name_scope, outputs)

    # Not supported by layer class: batch_weights argument,
    # and custom updates_collections. In that case, use the legacy BN
    # implementation.
    # Custom updates collections are not supported because the update logic
    # is different in this case, in particular w.r.t. "forced updates" and
    # update op reuse.
    if renorm:
      raise ValueError('renorm is not supported with batch_weights, '
                       'updates_collections or zero_debias_moving_mean')
    inputs_shape = inputs.get_shape()
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
      raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    dtype = inputs.dtype.base_dtype
    if batch_weights is not None:
      batch_weights = ops.convert_to_tensor(batch_weights)
      inputs_shape[0:1].assert_is_compatible_with(batch_weights.get_shape())
      # Reshape batch weight values so they broadcast across inputs.
      nshape = [-1] + [1 for _ in range(inputs_rank - 1)]
      batch_weights = array_ops.reshape(batch_weights, nshape)

    if data_format == DATA_FORMAT_NCHW:
      moments_axes = [0] + list(range(2, inputs_rank))
      params_shape = inputs_shape[1:2]
      # For NCHW format, rather than relying on implicit broadcasting, we
      # explicitly reshape the params to params_shape_broadcast when computing
      # the moments and the batch normalization.
      params_shape_broadcast = list(
          [1, inputs_shape[1].value] + [1 for _ in range(2, inputs_rank)])
    else:
      moments_axes = list(range(inputs_rank - 1))
      params_shape = inputs_shape[-1:]
      params_shape_broadcast = None
    if not params_shape.is_fully_defined():
      raise ValueError('Inputs %s has undefined channels dimension %s.' % (
          inputs.name, params_shape))

    # Allocate parameters for the beta and gamma of the normalization.
    beta, gamma = None, None
    if not param_initializers:
      param_initializers = {}
    if center:
      beta_collections = utils.get_variable_collections(variables_collections,
                                                        'beta')
      beta_initializer = param_initializers.get('beta',
                                                init_ops.zeros_initializer())
      beta = variables.model_variable('beta',
                                      shape=params_shape,
                                      dtype=dtype,
                                      initializer=beta_initializer,
                                      collections=beta_collections,
                                      trainable=trainable)
    if scale:
      gamma_collections = utils.get_variable_collections(variables_collections,
                                                         'gamma')
      gamma_initializer = param_initializers.get('gamma',
                                                 init_ops.ones_initializer())
      gamma = variables.model_variable('gamma',
                                       shape=params_shape,
                                       dtype=dtype,
                                       initializer=gamma_initializer,
                                       collections=gamma_collections,
                                       trainable=trainable)

    # Create moving_mean and moving_variance variables and add them to the
    # appropriate collections. We disable variable partitioning while creating
    # them, because assign_moving_average is not yet supported for partitioned
    # variables.
    partitioner = variable_scope.get_variable_scope().partitioner
    try:
      variable_scope.get_variable_scope().set_partitioner(None)
      moving_mean_collections = utils.get_variable_collections(
          variables_collections, 'moving_mean')
      moving_mean_initializer = param_initializers.get(
          'moving_mean', init_ops.zeros_initializer())
      moving_mean = variables.model_variable(
          'moving_mean',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_mean_initializer,
          trainable=False,
          collections=moving_mean_collections)
      moving_variance_collections = utils.get_variable_collections(
          variables_collections, 'moving_variance')
      moving_variance_initializer = param_initializers.get(
          'moving_variance', init_ops.ones_initializer())
      moving_variance = variables.model_variable(
          'moving_variance',
          shape=params_shape,
          dtype=dtype,
          initializer=moving_variance_initializer,
          trainable=False,
          collections=moving_variance_collections)
    finally:
      variable_scope.get_variable_scope().set_partitioner(partitioner)

    # If `is_training` doesn't have a constant value (because it is a
    # `Tensor`, a `Variable` or a `Placeholder`), `is_training_value`
    # will be None and `need_moments` will be True.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
      # Calculate the moments based on the individual batch.
      if batch_weights is None:
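        # Note: `variance` below is computed as the second moment of
        # (inputs - moving_mean) rather than the usual batch variance; this
        # appears to be the modification relative to the stock tf.contrib
        # implementation that the docstring above refers to.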
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.moments(inputs, moments_axes, keep_dims=True)
          variance, _ = nn.moments((inputs - moving_mean)**2, moments_axes,
                                   keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.moments(inputs, moments_axes)
          variance, _ = nn.moments((inputs - moving_mean)**2, moments_axes)
      else:
        if data_format == DATA_FORMAT_NCHW:
          mean, _ = nn.weighted_moments(inputs, moments_axes, batch_weights,
                                        keep_dims=True)
          variance, _ = nn.weighted_moments((inputs - moving_mean)**2,
                                            moments_axes, batch_weights,
                                            keep_dims=True)
          mean = array_ops.reshape(mean, [-1])
          variance = array_ops.reshape(variance, [-1])
        else:
          mean, _ = nn.weighted_moments(inputs, moments_axes, batch_weights)
          variance, _ = nn.weighted_moments((inputs - moving_mean)**2,
                                            moments_axes, batch_weights)

      moving_vars_fn = lambda: (moving_mean, moving_variance)
      if updates_collections is None:
        def _force_updates():
          """Internal function forces updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          with ops.control_dependencies([update_moving_mean,
                                         update_moving_variance]):
            return array_ops.identity(mean), array_ops.identity(variance)
        mean, variance = utils.smart_cond(is_training,
                                          _force_updates,
                                          moving_vars_fn)
      else:
        def _delay_updates():
          """Internal function that delay updates moving_vars if is_training."""
          update_moving_mean = moving_averages.assign_moving_average(
              moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
          update_moving_variance = moving_averages.assign_moving_average(
              moving_variance, variance, decay, zero_debias=False)
          return update_moving_mean, update_moving_variance

        update_mean, update_variance = utils.smart_cond(is_training,
                                                        _delay_updates,
                                                        moving_vars_fn)
        ops.add_to_collections(updates_collections, update_mean)
        ops.add_to_collections(updates_collections, update_variance)
        # Use computed moments during training and moving_vars otherwise.
        vars_fn = lambda: (mean, variance)
        mean, variance = utils.smart_cond(is_training, vars_fn, moving_vars_fn)
    else:
      mean, variance = moving_mean, moving_variance
    if data_format == DATA_FORMAT_NCHW:
      mean = array_ops.reshape(mean, params_shape_broadcast)
      variance = array_ops.reshape(variance, params_shape_broadcast)
      if beta is not None:
        beta = array_ops.reshape(beta, params_shape_broadcast)
      if gamma is not None:
        gamma = array_ops.reshape(gamma, params_shape_broadcast)

    # Compute batch_normalization.
    outputs = nn.batch_normalization(inputs, mean, variance, beta, gamma,
                                     epsilon)
    outputs.set_shape(inputs_shape)
    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
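A minimal usage sketch for `batch_norm_mine_old`, assuming a hypothetical NHWC activation tensor `net`; because `is_training` is fed as a placeholder here, `utils.constant_value` returns None and both the training and inference paths are built:

import tensorflow as tf

net = tf.placeholder(tf.float32, [None, 32, 32, 64], name='net')
is_training = tf.placeholder_with_default(True, shape=[], name='is_training')

normalized = batch_norm_mine_old(net, decay=0.99, scale=True,
                                 is_training=is_training)

# The moving-average updates land in tf.GraphKeys.UPDATE_OPS (the default
# updates_collections), so they must be run alongside the train op.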
Example #19
0
def fused_batch_norm(
        inputs,
        renorm=False,
        RMAX=None,
        DMAX=None,
        decay=0.999,
        center=True,
        scale=False,
        epsilon=0.001,
        activation_fn=None,
        param_initializers=None,
        is_training=True,
        reuse=None,
        variables_collections=None,
        outputs_collections=None,
        trainable=True,
        data_format=DATA_FORMAT_NHWC,
        zero_debias_moving_mean=False,
        scope=None):
    """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167.

        "Batch Normalization: Accelerating Deep Network Training by Reducing
        Internal Covariate Shift"

        Sergey Ioffe, Christian Szegedy

    Can be used as a normalizer function for conv2d and fully_connected.

    Note: When is_training is True the moving_mean and moving_variance need to be
    updated. By default the update_ops are placed in `tf.GraphKeys.UPDATE_OPS`, so
    they need to be added as a dependency to the `train_op`, for example:

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
          updates = tf.group(*update_ops)
          total_loss = control_flow_ops.with_dependencies([updates], total_loss)

    Args:
        inputs: a tensor with 2 or more dimensions, where the first dimension has
        `batch_size`. The normalization is over all but the last dimension if
        `data_format` is `NHWC` and the second dimension if `data_format` is
        `NCHW`.
        decay: decay for the moving average. Reasonable values for `decay` are close
        to 1.0, typically in the multiple-nines range: 0.999, 0.99, 0.9, etc.
        Lower `decay` value (recommend trying `decay`=0.9) if model experiences
        reasonably good training performance but poor validation and/or test
        performance.
        center: If True, add offset of `beta` to normalized tensor.  If False,
        `beta` is ignored.
        scale: If True, multiply by `gamma`. If False, `gamma` is
        not used. When the next layer is linear (also e.g. `nn.relu`), this can be
        disabled since the scaling can be done by the next layer.
        epsilon: small float added to variance to avoid dividing by zero.
        activation_fn: activation function, default set to None to skip it and
        maintain a linear activation.
        param_initializers: optional initializers for beta, gamma, moving mean and
        moving variance.
        renorm: whether to apply the Batch Renormalization correction
        (https://arxiv.org/abs/1702.03275) during training.
        RMAX: scalar `Tensor` used to clip the renorm correction `r` to
        `[1/RMAX, RMAX]`. Only used when `renorm` is True.
        DMAX: scalar `Tensor` used to clip the renorm correction `d` to
        `[-DMAX, DMAX]`. Only used when `renorm` is True.
        is_training: whether or not the layer is in training mode. In training mode
        it would accumulate the statistics of the moments into `moving_mean` and
        `moving_variance` using an exponential moving average with the given
        `decay`. When it is not in training mode then it would use the values of
        the `moving_mean` and the `moving_variance`.
        reuse: whether or not the layer and its variables should be reused. To be
        able to reuse the layer scope must be given.
        variables_collections: optional collections for the variables.
        outputs_collections: collections to add the outputs.
        trainable: If `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
        data_format: A string. `NHWC` (default) and `NCHW` are supported.
        zero_debias_moving_mean: Use zero_debias for moving_mean.
        scope: Optional scope for `variable_scope`.

    Returns:
        A `Tensor` representing the output of the operation.

    Raises:
        ValueError: if `data_format` is neither `NHWC` nor `NCHW`.
        ValueError: if the rank of `inputs` is undefined.
        ValueError: if the rank of `inputs` is neither 2 nor 4.
        ValueError: if rank or `C` dimension of `inputs` is undefined.
    """
    if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
        raise ValueError('data_format has to be either NCHW or NHWC.')
    with tf.variable_scope(
            scope, 'BatchNorm', [inputs], reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        original_shape = inputs.get_shape()
        original_rank = original_shape.ndims
        if original_rank is None:
            raise ValueError('Inputs %s has undefined rank' % inputs.name)
        elif original_rank not in [2, 4]:
            raise ValueError('Inputs %s has unsupported rank.'
                            ' Expected 2 or 4 but got %d' % (
                                inputs.name, original_rank))
        if original_rank == 2:
            channels = inputs.get_shape()[-1].value
            if channels is None:
                raise ValueError('`C` dimension must be known but is None')
            new_shape = [-1, 1, 1, channels]
            if data_format == DATA_FORMAT_NCHW:
                new_shape = [-1, channels, 1, 1]
            inputs = array_ops.reshape(inputs, new_shape)
        inputs_shape = inputs.get_shape()
        dtype = inputs.dtype.base_dtype
        if data_format == DATA_FORMAT_NHWC:
            params_shape = inputs_shape[-1:]
        else:
            params_shape = inputs_shape[1:2]
        if not params_shape.is_fully_defined():
            raise ValueError('Inputs %s has undefined `C` dimension %s.' %
                            (inputs.name, params_shape))

        if not param_initializers:
            param_initializers = {}
        # Allocate parameters for the beta and gamma of the normalization.
        trainable_beta = trainable and center
        if trainable_beta:
            beta_collections = utils.get_variable_collections(variables_collections,
                                                            'beta')
            beta_initializer = param_initializers.get('beta',
                                                    init_ops.zeros_initializer())
            real_beta = variables.model_variable(
                    'beta',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=beta_initializer,
                    collections=beta_collections,
                    trainable=trainable_beta)
            beta = tf.zeros(params_shape, name='fakebeta')
        else:
            real_beta = tf.zeros(params_shape, name='beta')
            beta = tf.zeros(params_shape, name='fakebeta')
        trainable_gamma = trainable and scale
        if trainable_gamma:
            gamma_collections = utils.get_variable_collections(variables_collections,
                                                            'gamma')
            gamma_initializer = param_initializers.get('gamma',
                                                    init_ops.ones_initializer())
            gamma = variables.model_variable(
                    'gamma',
                    shape=params_shape,
                    dtype=dtype,
                    initializer=gamma_initializer,
                    collections=gamma_collections,
                    trainable=trainable_gamma)
        else:
            gamma = tf.ones(params_shape, name='gamma')

        # Create moving_mean and moving_variance variables and add them to the
        # appropriate collections.
        moving_mean_collections = utils.get_variable_collections(
                variables_collections, 'moving_mean')
        moving_mean_initializer = param_initializers.get(
                'moving_mean', init_ops.zeros_initializer())
        moving_mean = variables.model_variable(
                'moving_mean',
                shape=params_shape,
                dtype=dtype,
                initializer=moving_mean_initializer,
                trainable=False,
                collections=moving_mean_collections)
        moving_variance_collections = utils.get_variable_collections(
                variables_collections, 'moving_variance')
        moving_variance_initializer = param_initializers.get(
                'moving_variance', init_ops.ones_initializer())
        moving_variance = variables.model_variable(
                'moving_variance',
                shape=params_shape,
                dtype=dtype,
                initializer=moving_variance_initializer,
                trainable=False,
                collections=moving_variance_collections)

        def _fused_batch_norm_training():
            outputs, mean, variance = nn.fused_batch_norm(
                    inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
            if renorm:
                moving_inv = math_ops.rsqrt(moving_variance + epsilon)
                r = tf.stop_gradient(tf.clip_by_value(tf.sqrt(variance + epsilon) * moving_inv,
                                                        1/RMAX,
                                                        RMAX))
                d = tf.stop_gradient(tf.clip_by_value((mean - moving_mean) * moving_inv,
                                                        -DMAX,
                                                        DMAX))
                outputs = outputs * r + d
            return outputs, mean, variance
        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(
                    inputs,
                    gamma,
                    beta,
                    mean=moving_mean,
                    variance=moving_variance,
                    epsilon=epsilon,
                    is_training=False,
                    data_format=data_format)
        outputs, mean, variance = utils.smart_cond(is_training,
                                                _fused_batch_norm_training,
                                                _fused_batch_norm_inference)
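        # The fused op above was given the zero-valued 'fakebeta', so the real
        # (possibly trainable) beta is added here, after the optional renorm
        # correction has been applied to the normalized outputs.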
        outputs = tf.nn.bias_add(outputs, real_beta)

        # If `is_training` doesn't have a constant value (because it is a
        # `Tensor`, a `Variable` or a `Placeholder`), `is_training_value`
        # will be None and `need_updates` will be True.
        is_training_value = utils.constant_value(is_training)
        need_updates = is_training_value is None or is_training_value
        if need_updates:
            moving_vars_fn = lambda: (moving_mean, moving_variance)
            def _delay_updates():
                """Internal function that delay updates moving_vars if is_training."""
                update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean, mean, decay, zero_debias=zero_debias_moving_mean)
                update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, variance, decay, zero_debias=False)
                return update_moving_mean, update_moving_variance
            update_mean, update_variance = utils.smart_cond(is_training,
                                                            _delay_updates,
                                                            moving_vars_fn)
            ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_mean)
            ops.add_to_collections(ops.GraphKeys.UPDATE_OPS, update_variance)

        outputs.set_shape(inputs_shape)
        if original_shape.ndims == 2:
            outputs = array_ops.reshape(outputs, original_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections,
                                        sc.original_name_scope, outputs)
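For reference, a small NumPy sketch (names assumed) of the renorm correction applied in `_fused_batch_norm_training` above: the batch statistics are tied back to the moving statistics, with `r` and `d` clipped by `RMAX` and `DMAX` as in https://arxiv.org/abs/1702.03275.

import numpy as np

def renorm_correction(batch_mean, batch_var, moving_mean, moving_var,
                      rmax, dmax, epsilon=0.001):
    # Mirrors the r/d computation above; both are treated as constants
    # (stop_gradient) when correcting the normalized activations.
    moving_inv = 1.0 / np.sqrt(moving_var + epsilon)
    r = np.clip(np.sqrt(batch_var + epsilon) * moving_inv, 1.0 / rmax, rmax)
    d = np.clip((batch_mean - moving_mean) * moving_inv, -dmax, dmax)
    return r, d  # corrected outputs are then normalized * r + d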
Example #20
0
    def call(self, inputs, training=False):
        # First, compute the axes along which to reduce the mean / variance,
        # as well as the broadcast shape to be used for all parameters.
        input_shape = inputs.get_shape()
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis].value

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = utils.constant_value(training)

        if needs_broadcasting:
            # In this case we must explicitly broadcast all parameters.
            if self.center:
                broadcast_beta = array_ops.reshape(self.beta, broadcast_shape)
            else:
                broadcast_beta = None
            if self.scale:
                broadcast_gamma = array_ops.reshape(self.gamma,
                                                    broadcast_shape)
            else:
                broadcast_gamma = None

        # Determines moments
        if training_value is not False:
            if needs_broadcasting:
                broadcast_mean, broadcast_variance = nn.moments(inputs,
                                                                reduction_axes,
                                                                keep_dims=True)
                mean = array_ops.reshape(broadcast_mean, [-1])
                variance = array_ops.reshape(broadcast_variance, [-1])
            else:
                mean, variance = nn.moments(inputs, reduction_axes)

            # Prepare updates if necessary.
            if not self.updates:
                mean_update = moving_averages.assign_moving_average(
                    self.moving_mean, mean, self.momentum, zero_debias=False)
                variance_update = moving_averages.assign_moving_average(
                    self.moving_variance,
                    variance,
                    self.momentum,
                    zero_debias=False)
                # In the future this should be refactored into a self.add_update
                # method in order to allow for instance-based BN layer sharing
                # across unrelated input streams (e.g. like in Keras).
                self.updates.append(mean_update)
                self.updates.append(variance_update)

        # Normalize batch. We do this inside separate functions for training
        # and inference so as to avoid evaluating both branches.
        def normalize_in_test():
            if needs_broadcasting:
                broadcast_moving_mean = array_ops.reshape(
                    self.moving_mean, broadcast_shape)
                broadcast_moving_variance = array_ops.reshape(
                    self.moving_variance, broadcast_shape)
            arg_mean = broadcast_moving_mean if needs_broadcasting else self.moving_mean
            arg_variance = broadcast_moving_variance if needs_broadcasting else self.moving_variance
            arg_beta = broadcast_beta if needs_broadcasting else (
                self.beta if self.center else None)
            arg_gamma = broadcast_gamma if needs_broadcasting else (
                self.gamma if self.scale else None)
            if self.quantizer is None:
                return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                              arg_beta, arg_gamma,
                                              self.epsilon)
            else:
                return qbatch_normalization(inputs, arg_mean, arg_variance,
                                            arg_beta, arg_gamma, self.epsilon,
                                            self.quantizer)

        def normalize_in_training():
            arg_mean = broadcast_mean if needs_broadcasting else mean
            arg_variance = broadcast_variance if needs_broadcasting else variance
            arg_beta = broadcast_beta if needs_broadcasting else (
                self.beta if self.center else None)
            arg_gamma = broadcast_gamma if needs_broadcasting else (
                self.gamma if self.scale else None)
            if self.quantizer is None:
                return nn.batch_normalization(inputs, arg_mean, arg_variance,
                                              arg_beta, arg_gamma,
                                              self.epsilon)
            else:
                return qbatch_normalization(inputs, arg_mean, arg_variance,
                                            arg_beta, arg_gamma, self.epsilon,
                                            self.quantizer)

        return utils.smart_cond(training, normalize_in_training,
                                normalize_in_test)
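The `call` method above relies on the `utils.constant_value` / `utils.smart_cond` pair that recurs throughout these examples: a Python bool lets one branch be pruned at graph-construction time, while a `Tensor` predicate forces a `tf.cond`. A minimal sketch, assuming `utils` is tensorflow.contrib.layers.python.layers.utils as in the examples above:

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import utils

training = tf.placeholder(tf.bool, shape=[], name='training')
batch_stats = lambda: tf.constant(0.0, name='batch_mean')
moving_stats = lambda: tf.constant(1.0, name='moving_mean')

print(utils.constant_value(True))       # -> True (statically known)
print(utils.constant_value(training))   # -> None (only known at run time)

# With a Python bool, smart_cond simply calls the chosen branch; with a
# Tensor predicate it builds both branches inside a tf.cond.
mean = utils.smart_cond(training, batch_stats, moving_stats)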
Example #21
0
def conditional_batch_norm(inputs,
                           conditional_layer,
                           var_scope_postfix='',
                           decay=0.999,
                           center=True,
                           scale=False,
                           epsilon=0.001,
                           activation_fn=None,
                           param_initializers=None,
                           param_regularizers=None,
                           updates_collections=tf.GraphKeys.UPDATE_OPS,
                           is_training=True,
                           reuse=None,
                           variables_collections=None,
                           outputs_collections=None,
                           trainable=True,
                           data_format=DATA_FORMAT_NHWC,
                           zero_debias_moving_mean=False,
                           renorm=False,
                           renorm_clipping=None,
                           renorm_momentum=0.99,
                           scope=None):
    """Custom implementation of batch norm  to support the optional `conditional_layer` and `var_scope_postfix`.
  For comments on the other parameters, see tensorflow.contrib.layers.python.layers.batch_norm, where this is copied
  from (tf 1.5 version).

  Args:
    conditional_layer: A tensor with 2 dimensions [batch, channels]. If not None, the beta and gamma parameters will
      be conditioned on the `conditional_layer`.
    var_scope_postfix: A string. Append it to the var scopes of all variables other than the weight and bias. e.g.
      var scope of the `gamma` variable becomes `'gamma' + var_scope_postfix`.
  """

    if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
        raise ValueError('data_format has to be either NCHW or NHWC.')
    if inputs.dtype != tf.float32:
        raise NotImplementedError(
            'This implementation may not be compatible with mixed precision training.'
        )
    with tf.variable_scope(scope, 'BatchNorm', [inputs], reuse=reuse) as sc:

        if conditional_layer is not None:
            conditional_layer = tf.convert_to_tensor(conditional_layer)
            # Normalizing the conditional layer seems to stabilize training a little.
            conditional_layer = tf.nn.l2_normalize(
                conditional_layer, dim=1, name='normalized_conditional_layer')
            conditional_layer_shape = conditional_layer.get_shape()
            conditional_layer_rank = conditional_layer_shape.ndims
            if conditional_layer_rank is None:
                raise ValueError('Conditional layer %s has undefined rank' %
                                 conditional_layer.name)
            elif conditional_layer_rank != 2:
                raise ValueError('Conditional layer %s is not rank 2.' %
                                 conditional_layer.name)

        inputs = tf.convert_to_tensor(inputs)
        original_shape = inputs.get_shape()
        original_inputs = inputs
        original_rank = original_shape.ndims
        if original_rank is None:
            raise ValueError('Inputs %s has undefined rank' % inputs.name)
        elif original_rank not in [2, 4]:
            raise ValueError('Inputs %s has unsupported rank.'
                             ' Expected 2 or 4 but got %d' %
                             (inputs.name, original_rank))
        if original_rank == 2:
            channels = inputs.get_shape()[-1].value
            if channels is None:
                raise ValueError('`C` dimension must be known but is None')
            new_shape = [-1, 1, 1, channels]
            if data_format == DATA_FORMAT_NCHW:
                new_shape = [-1, channels, 1, 1]
            inputs = tf.reshape(inputs, new_shape)
        inputs_shape = inputs.get_shape()
        if data_format == DATA_FORMAT_NHWC:
            params_shape = inputs_shape[-1:]
        else:
            params_shape = inputs_shape[1:2]
        if not params_shape.is_fully_defined():
            raise ValueError('Inputs %s has undefined `C` dimension %s.' %
                             (inputs.name, params_shape))

        # Allocate parameters for the beta and gamma of the normalization.
        beta_collections = utils.get_variable_collections(
            variables_collections, 'beta')
        variable_dtype = inputs.dtype
        if not param_initializers:
            param_initializers = {}
        if not param_regularizers:
            param_regularizers = {}

        if center:
            beta_scope = 'beta' + var_scope_postfix
            if conditional_layer is not None:
                assert not param_initializers, 'param_initializers are not supported with conditional layer.'
                assert not param_regularizers, 'param_regularizers are not supported with conditional layer.'
                beta = get_conditional_batch_norm_param(conditional_layer,
                                                        int(params_shape[-1]),
                                                        scope=beta_scope)
            else:
                # Behaves like normal batch norm.
                beta_collections = utils.get_variable_collections(
                    variables_collections, beta_scope)
                beta_initializer = param_initializers.get(
                    beta_scope, tf.zeros_initializer())
                beta_regularizer = param_regularizers.get('beta')
                beta = variables.model_variable(beta_scope,
                                                shape=params_shape,
                                                dtype=variable_dtype,
                                                initializer=beta_initializer,
                                                regularizer=beta_regularizer,
                                                collections=beta_collections,
                                                trainable=trainable)
        else:
            beta = array_ops.constant(0.0,
                                      dtype=variable_dtype,
                                      shape=params_shape)

        if scale:
            gamma_scope = 'gamma' + var_scope_postfix
            if conditional_layer is not None:
                assert not param_initializers, 'param_initializers are not supported with conditional layer.'
                assert not param_regularizers, 'param_regularizers are not supported with conditional layer.'
                delta_gamma = get_conditional_batch_norm_param(
                    conditional_layer,
                    int(params_shape[-1]),
                    scope=gamma_scope)
                # Per https://arxiv.org/pdf/1707.03017.pdf.
                gamma = tf.constant(
                    1.0,
                    dtype=variable_dtype,
                ) + delta_gamma
            else:
                gamma_collections = utils.get_variable_collections(
                    variables_collections, gamma_scope)
                gamma_initializer = param_initializers.get(
                    gamma_scope, tf.ones_initializer())
                gamma_regularizer = param_regularizers.get('gamma')
                gamma = variables.model_variable(gamma_scope,
                                                 shape=params_shape,
                                                 dtype=variable_dtype,
                                                 initializer=gamma_initializer,
                                                 regularizer=gamma_regularizer,
                                                 collections=gamma_collections,
                                                 trainable=trainable)
        else:
            gamma = tf.constant(1.0, dtype=variable_dtype, shape=params_shape)

        # Create moving_mean and moving_variance variables and add them to the
        # appropriate collections. We disable variable partitioning while creating
        # them, because assign_moving_average is not yet supported for partitioned
        # variables (this needs to be handled carefully, as it may break
        # the checkpoint backward compatibility).
        with tf.variable_scope(tf.get_variable_scope()) as local_scope:
            local_scope.set_partitioner(None)
            moving_mean_scope = 'moving_mean' + var_scope_postfix
            moving_mean_collections = utils.get_variable_collections(
                variables_collections, moving_mean_scope)
            moving_mean_initializer = param_initializers.get(
                moving_mean_scope, tf.zeros_initializer())
            moving_mean = variables.model_variable(
                moving_mean_scope,
                shape=params_shape,
                dtype=tf.float32,
                initializer=moving_mean_initializer,
                trainable=False,
                collections=moving_mean_collections)
            moving_variance_scope = 'moving_variance' + var_scope_postfix
            moving_variance_collections = utils.get_variable_collections(
                variables_collections, moving_variance_scope)
            moving_variance_initializer = param_initializers.get(
                moving_variance_scope, tf.ones_initializer())
            moving_variance = variables.model_variable(
                moving_variance_scope,
                shape=params_shape,
                dtype=tf.float32,
                initializer=moving_variance_initializer,
                trainable=False,
                collections=moving_variance_collections)

            if renorm:
                renorm_clipping = renorm_clipping or {}
                keys = ['rmax', 'rmin', 'dmax']
                if set(renorm_clipping) - set(keys):
                    raise ValueError(
                        'renorm_clipping %s contains keys not in %s' %
                        (renorm_clipping, keys))

                # Create variables to maintain the moving mean and standard deviation.
                # These are used in training and thus are different from the moving
                # averages above. The renorm variables are colocated with moving_mean
                # and moving_variance.
                # NOTE: below, the outer `with device` block causes the current device
                # stack to be cleared. The nested ones use a `lambda` to set the desired
                # device and ignore any devices that may be set by the custom getter.
                def _renorm_variable(name, shape):
                    # The renorm variable name depends on var_scope_postfix.
                    var = variables.model_variable(
                        name=name,
                        shape=shape,
                        dtype=tf.float32,
                        initializer=param_initializers.get(
                            name, tf.zeros_initializer()),
                        trainable=False)
                    return var

                with ops.device(None):
                    device = ((lambda _: moving_mean.device)
                              if context.executing_eagerly() else
                              moving_mean.device)
                    with ops.device(device):
                        renorm_mean = _renorm_variable(
                            'renorm_mean' + var_scope_postfix, params_shape)
                        renorm_mean_weight = _renorm_variable(
                            'renorm_mean_weight' + var_scope_postfix, ())
                    # We initialize renorm_stddev to 0, and maintain the (0-initialized)
                    # renorm_stddev_weight. This allows us to (1) mix the average
                    # stddev with the minibatch stddev early in training, and (2) compute
                    # the unbiased average stddev by dividing renorm_stddev by the weight.
                    device = ((lambda _: moving_variance.device)
                              if context.executing_eagerly() else
                              moving_variance.device)
                    with ops.device(device):
                        renorm_stddev = _renorm_variable(
                            'renorm_stddev' + var_scope_postfix, params_shape)
                        renorm_stddev_weight = _renorm_variable(
                            'renorm_stddev_weight' + var_scope_postfix, ())

                class dotdict(dict):
                    """dot.notation access to dictionary attributes"""
                    __getattr__ = dict.get
                    __setattr__ = dict.__setitem__
                    __delattr__ = dict.__delitem__

                renorm_params = dotdict({
                    'renorm_mean': renorm_mean,
                    'renorm_mean_weight': renorm_mean_weight,
                    'renorm_stddev': renorm_stddev,
                    'renorm_stddev_weight': renorm_stddev_weight,
                    'renorm_clipping': renorm_clipping,
                    'renorm_momentum': renorm_momentum,
                    'moving_mean': moving_mean,
                    'moving_variance': moving_variance,
                    'epsilon': epsilon
                })
            else:
                renorm_params = None

        def _batch_norm_training():
            # return tf.nn.fused_batch_norm(
            return _batch_norm_aux(inputs,
                                   gamma,
                                   beta,
                                   epsilon=epsilon,
                                   data_format=data_format,
                                   renorm=renorm,
                                   renorm_params=renorm_params)

        def _batch_norm_inference():
            # return tf.nn.fused_batch_norm(
            return _batch_norm_aux(inputs,
                                   gamma,
                                   beta,
                                   mean=tf.cast(moving_mean,
                                                dtype=variable_dtype),
                                   variance=tf.cast(moving_variance,
                                                    dtype=variable_dtype),
                                   epsilon=epsilon,
                                   is_training=False,
                                   data_format=data_format,
                                   renorm=renorm,
                                   renorm_params=renorm_params)

        outputs, mean, variance = utils.smart_cond(is_training,
                                                   _batch_norm_training,
                                                   _batch_norm_inference)

        # If `is_training` doesn't have a constant value (because it is a
        # `Tensor`, a `Variable` or a `Placeholder`), `is_training_value`
        # will be None and `need_updates` will be True.
        is_training_value = utils.constant_value(is_training)
        need_updates = is_training_value is None or is_training_value
        if need_updates:
            if updates_collections is None:
                no_updates = lambda: outputs

                def _force_updates():
                    """Internal function forces updates moving_vars if is_training."""
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean,
                        mean,
                        decay,
                        zero_debias=zero_debias_moving_mean)
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance, variance, decay, zero_debias=False)
                    with tf.control_dependencies(
                        [update_moving_mean, update_moving_variance]):
                        return tf.identity(outputs)

                outputs = utils.smart_cond(is_training, _force_updates,
                                           no_updates)
            else:
                moving_vars_fn = lambda: (moving_mean, moving_variance)

                def _delay_updates():
                    """Internal function that delay updates moving_vars if is_training."""
                    update_moving_mean = moving_averages.assign_moving_average(
                        moving_mean,
                        tf.cast(mean, dtype=moving_mean.dtype),
                        decay,
                        zero_debias=zero_debias_moving_mean)
                    update_moving_variance = moving_averages.assign_moving_average(
                        moving_variance,
                        tf.cast(variance, dtype=moving_variance.dtype),
                        decay,
                        zero_debias=False)
                    return update_moving_mean, update_moving_variance

                update_mean, update_variance = utils.smart_cond(
                    is_training, _delay_updates, moving_vars_fn)
                ops.add_to_collections(updates_collections, update_mean)
                ops.add_to_collections(updates_collections, update_variance)

        outputs.set_shape(inputs_shape)
        if original_shape.ndims == 2:
            outputs = array_ops.reshape(outputs,
                                        array_ops.shape(original_inputs))
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections, sc.name,
                                           outputs)
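`get_conditional_batch_norm_param` is referenced above but not shown; a plausible (purely hypothetical) sketch is a learned linear projection from the conditioning vector to one value per channel, in the spirit of conditional batch normalization:

import tensorflow as tf

def get_conditional_batch_norm_param_sketch(conditional_layer, num_channels,
                                            scope='conditional_param'):
    # Hypothetical sketch: project the [batch, cond_dim] conditioning vector
    # to one offset/scale value per channel. Zero-initialized so the layer
    # starts out behaving like an unconditional batch norm.
    with tf.variable_scope(scope):
        return tf.layers.dense(conditional_layer, num_channels,
                               kernel_initializer=tf.zeros_initializer())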
Example #22
0
 def test_value(self):
   for v in [True, False, 1, 0, 1.0]:
     value = utils.constant_value(v)
     self.assertEqual(value, v)