Example No. 1
    def _get_queue_ops(self, var_update_op: ops.Operation,
                       source_op: ops.Operation, is_chief: bool,
                       is_trainable: bool) -> List[ops.Operation]:
        """
        Get queue operations for synchronous parameter update.

        Maintain a list of queues of size 1. The chief machine pushes a token to each queue at the beginning
        of each update. The other workers then dequeue a token from their corresponding queue if their gradient
        is sent to the accumulator. The enqueue and dequeue operations are grouped and have to be completed
        before the model moves on to the next step, thus resulting in synchronous parameter update.

        Args:
            var_update_op: The op

        Returns:
            A list of queue operations.
        """
        var_op = var_update_op.inputs[UPDATE_OP_VAR_POS].op

        var_update_sync_queues = \
            [data_flow_ops.FIFOQueue(1, [dtypes.bool], shapes=[[]],
                                     name='%s_update_sync_queue_%d' % (var_op.name, i),
                                     shared_name='%s_update_sync_queue_%d' % (var_op.name, i))
             for i in range(self.num_workers)]

        queue_ops = []
        if is_chief:
            if is_trainable:
                var_update_deps = [
                    self._var_op_to_accum_apply_op[var_op], source_op
                ]
            else:
                var_update_deps = [var_update_op]
            # Chief enqueues tokens to all other workers after executing variable update
            token = constant_op.constant(False)
            with ops.control_dependencies(var_update_deps):
                for i, q in enumerate(var_update_sync_queues):
                    if i != self.worker_id:
                        queue_ops.append(q.enqueue(token))
                    else:
                        queue_ops.append(gen_control_flow_ops.no_op())
        else:
            # Wait for the chief's token, which is enqueued only after the variable update has run
            if is_trainable:
                with ops.control_dependencies(
                    [self._var_op_to_accum_apply_op[var_op]]):
                    dequeue = var_update_sync_queues[self.worker_id].dequeue()
            else:
                dequeue = var_update_sync_queues[self.worker_id].dequeue()
            queue_ops.append(dequeue)

        return queue_ops
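
For context, here is a minimal sketch (not from the example's codebase; the variable and queue names are hypothetical) of how such queue ops are typically grouped with the variable update so that a training step only finishes once the token exchange has happened:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Hypothetical single-variable setup mirroring the chief-side branch above.
var = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
var_update_op = tf.assign_add(var, 1.0)

# One size-1 queue per worker; the chief enqueues a token once the update has run.
sync_queue = tf.FIFOQueue(1, [tf.bool], shapes=[[]],
                          shared_name='w_update_sync_queue_0')
with tf.control_dependencies([var_update_op]):
    enqueue_op = sync_queue.enqueue(tf.constant(False))

# Grouping makes the step wait for both the update and the token exchange.
train_op = tf.group(var_update_op, enqueue_op, name='synced_update')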
Example No. 2
def test_parse_name_scope():
    with ops.Graph().as_default():
        name_scope = 'name_scope/child_name_scope'
        a = constant_op.constant(5)
        new_name = ops.prepend_name_scope(a.name, name_scope)
        assert new_name == 'name_scope/child_name_scope/Const:0'
        assert name_scope == utils.parse_name_scope(new_name)
        assert '' == utils.parse_name_scope(a.name)

        with ops.control_dependencies([no_op(name='my_op')]):
            b = constant_op.constant(6)
        name_scope = 'name_scope'
        new_name = ops.prepend_name_scope(b.op.node_def.input[0], name_scope)
        assert new_name == '^name_scope/my_op'
        assert name_scope == utils.parse_name_scope(new_name)
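
The helper under test is not shown above; a minimal sketch of a parse_name_scope consistent with these assertions (the real utils.parse_name_scope may differ) could look like this:

def parse_name_scope(name):
    """Return the name scope of a tensor or op name such as 'scope/sub/Const:0'."""
    # Strip a leading control-dependency marker, e.g. '^name_scope/my_op'.
    if name.startswith('^'):
        name = name[1:]
    # Drop the output index (':0') and keep everything before the last '/'.
    op_name = name.split(':')[0]
    return op_name.rsplit('/', 1)[0] if '/' in op_name else ''

assert parse_name_scope('name_scope/child_name_scope/Const:0') == 'name_scope/child_name_scope'
assert parse_name_scope('^name_scope/my_op') == 'name_scope'
assert parse_name_scope('Const:0') == ''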
Example No. 3
def update_loss_scale(self, finite_grads):
    del finite_grads
    return gen_control_flow_ops.no_op()
Example No. 4
def close(self, name=None):
    """See TensorArray."""
    return gen_control_flow_ops.no_op(name=name)
Example No. 5
def close(self, name=None):
    """See TensorArray."""
    return gen_control_flow_ops.no_op(name=name)

def update_loss_scale(self, finite_grads):
    del finite_grads
    return gen_control_flow_ops.no_op()
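
Examples 3-5 use no_op() as a do-nothing placeholder wherever callers expect an operation to run. A hedged sketch of that pattern for a fixed loss-scale policy (the class and attribute names here are illustrative, not taken from the examples' codebases):

import tensorflow.compat.v1 as tf

class FixedLossScaleManager:
    """Illustrative fixed loss scale: there is nothing to update."""

    def __init__(self, loss_scale=128.0):
        self._loss_scale = tf.constant(loss_scale)

    def get_loss_scale(self):
        return self._loss_scale

    def update_loss_scale(self, finite_grads):
        del finite_grads   # unused: the scale never changes
        return tf.no_op()  # still an op, so callers can group it or sess.run it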
Example No. 7
    def _no_op():
        return gen_control_flow_ops.no_op()

    def call(self, inputs, training=None, use_moving_statistics=True):
        """
        :param inputs: input features
        :param training: boolean or boolean Tensor (with shape []) which determines the current training phase
        :param use_moving_statistics: boolean or boolean Tensor (with shape []) which selects statistics to use
               when training==True (or the Tensor value) statistics (mean and variance) are from the inputs !
               when training==False, if use_moving_statistics==True -> feed forward with moving statistics (updated
                                        with operations defined in GraphKeys.UPDATE_OPS)
                                     else (use_moving_statistics==False -> feed forward with raw statistics (updated
                                        with operations from collections 'UPDATE_BN_OPS'
                                        'RESET_BN_OPS' contains operations to reset these vaiables between inferences.
        """
        in_eager_mode = context.executing_eagerly()
        if self.virtual_batch_size is not None:
            # Virtual batches (aka ghost batches) can be simulated by reshaping the
            # Tensor and reusing the existing batch norm implementation
            original_shape = [-1] + inputs.shape.as_list()[1:]
            expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

            # Will cause errors if virtual_batch_size does not divide the batch size
            inputs = array_ops.reshape(inputs, expanded_shape)

            def undo_virtual_batching(outputs):
                outputs = array_ops.reshape(outputs, original_shape)
                return outputs

        if self.fused:
            outputs = self._fused_batch_norm(
                inputs,
                training=training,
                use_moving_statistics=use_moving_statistics)
            if self.virtual_batch_size is not None:
                # Currently never reaches here since fused_batch_norm does not support
                # virtual batching
                outputs = undo_virtual_batching(outputs)
            return outputs

        # Compute the axes along which to reduce the mean / variance
        input_shape = inputs.get_shape()
        ndims = len(input_shape)
        reduction_axes = [i for i in range(ndims) if i not in self.axis]
        if self.virtual_batch_size is not None:
            del reduction_axes[1]  # Do not reduce along virtual batch dim

        # Broadcasting only necessary for single-axis batch norm where the axis is
        # not the last dimension
        broadcast_shape = [1] * ndims
        broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

        def _broadcast(v):
            if (v is not None and len(v.get_shape()) != ndims
                    and reduction_axes != list(range(ndims - 1))):
                return array_ops.reshape(v, broadcast_shape)
            return v

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        def _compose_transforms(scale, offset, then_scale, then_offset):
            if then_scale is not None:
                scale *= then_scale
                offset *= then_scale
            if then_offset is not None:
                offset += then_offset
            return (scale, offset)

        # Determine a boolean value for `training`: could be True, False, or None.
        training_value = tf_utils.constant_value(training)

        if training_value is not False:
            if self.adjustment:
                adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
                # Adjust only during training.
                adj_scale = tf_utils.smart_cond(
                    training, lambda: adj_scale,
                    lambda: array_ops.ones_like(adj_scale))
                adj_bias = tf_utils.smart_cond(
                    training, lambda: adj_bias,
                    lambda: array_ops.zeros_like(adj_bias))
                scale, offset = _compose_transforms(adj_scale, adj_bias, scale,
                                                    offset)

            # Some of the computations here are not necessary when training==False
            # but not a constant. However, this makes the code simpler.
            keep_dims = self.virtual_batch_size is not None or len(
                self.axis) > 1

            # mean and variance of the current batch
            mean, variance = nn.moments(inputs,
                                        reduction_axes,
                                        keep_dims=keep_dims)

            mean = tf_utils.smart_cond(
                training, lambda: mean,
                lambda: tf_utils.smart_cond(use_moving_statistics,
                                            lambda: self.moving_mean,
                                            lambda: self.mean))
            variance = tf_utils.smart_cond(
                training, lambda: variance,
                lambda: tf_utils.smart_cond(use_moving_statistics,
                                            lambda: self.moving_variance,
                                            lambda: self.variance))

            if self.renorm:
                r, d, new_mean, new_variance = self._renorm_correction_and_moments(
                    mean, variance, training)
                # When training, the normalized values (say, x) will be transformed as
                # x * gamma + beta without renorm, and (x * r + d) * gamma + beta
                # = x * (r * gamma) + (d * gamma + beta) with renorm.
                r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
                d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
                scale, offset = _compose_transforms(r, d, scale, offset)
            else:
                new_mean, new_variance = mean, variance

            if self.virtual_batch_size is not None:
                # This isn't strictly correct since in ghost batch norm, you are
                # supposed to sequentially update the moving_mean and moving_variance
                # with each sub-batch. However, since the moving statistics are only
                # used during evaluation, it is more efficient to just update in one
                # step and should not make a significant difference in the result.
                new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
                new_variance = math_ops.reduce_mean(variance,
                                                    axis=1,
                                                    keepdims=True)

            def _do_update(var, value):
                if in_eager_mode and not self.trainable:
                    return
                return self._assign_moving_average(var, value, self.momentum)

            moving_mean_update = tf_utils.smart_cond(
                training, lambda: _do_update(self.moving_mean, new_mean),
                lambda: self.moving_mean)
            moving_variance_update = tf_utils.smart_cond(
                training,
                lambda: _do_update(self.moving_variance, new_variance),
                lambda: self.moving_variance)

            if not context.executing_eagerly():
                self.add_update(moving_mean_update, inputs=True)
                self.add_update(moving_variance_update, inputs=True)

            mean_update = self._update_statistics(self.mean, mean,
                                                  self.n_updates)
            variance_update = self._update_statistics(self.variance, variance,
                                                      self.n_updates)

            with ops.control_dependencies([mean_update, variance_update]):
                # update n_updates only after updating self.mean and self.variance
                update_n_updates = state_ops.assign_add(self.n_updates, 1.)
                ops.add_to_collection('UPDATE_BN_OPS', update_n_updates)

            reset_mean = state_ops.assign(self.mean,
                                          array_ops.zeros_like(self.mean))
            reset_variance = state_ops.assign(
                self.variance, array_ops.zeros_like(self.variance))
            reset_n_updates = state_ops.assign(self.n_updates, 0.)
            with ops.control_dependencies(
                [reset_mean, reset_variance, reset_n_updates]):
                reset_bn = gen_control_flow_ops.no_op("ResetBatchNormStats")
            ops.add_to_collection('RESET_BN_OPS', reset_bn)

        else:
            # training == False
            mean = tf_utils.smart_cond(use_moving_statistics,
                                       lambda: self.moving_mean,
                                       lambda: self.mean)
            variance = tf_utils.smart_cond(use_moving_statistics,
                                           lambda: self.moving_variance,
                                           lambda: self.variance)

        mean = math_ops.cast(mean, inputs.dtype)
        variance = math_ops.cast(variance, inputs.dtype)
        if offset is not None:
            offset = math_ops.cast(offset, inputs.dtype)
        outputs = nn.batch_normalization(inputs, _broadcast(mean),
                                         _broadcast(variance), offset, scale,
                                         self.epsilon)
        # If some components of the shape got lost due to adjustments, fix that.
        outputs.set_shape(input_shape)

        if self.virtual_batch_size is not None:
            outputs = undo_virtual_batching(outputs)

        return outputs
    def _fused_batch_norm(self, inputs, training, use_moving_statistics):
        """Returns the output of fused batch norm."""
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const

        def _fused_batch_norm_training():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       epsilon=self.epsilon,
                                       data_format=self._data_format)

        # If use_moving_statistics is True, use moving_mean/moving_variance; otherwise use mean/variance.
        mean = tf_utils.smart_cond(use_moving_statistics,
                                   lambda: self.moving_mean, lambda: self.mean)
        variance = tf_utils.smart_cond(use_moving_statistics,
                                       lambda: self.moving_variance,
                                       lambda: self.variance)

        # These tensors are captured by _fused_batch_norm_inference() through its closure.

        def _fused_batch_norm_inference():
            return nn.fused_batch_norm(inputs,
                                       gamma,
                                       beta,
                                       mean=mean,
                                       variance=variance,
                                       epsilon=self.epsilon,
                                       is_training=False,
                                       data_format=self._data_format)

        output, mean, variance = tf_utils.smart_cond(
            training, _fused_batch_norm_training, _fused_batch_norm_inference)
        # if training == True: the mean and variance returned are those of the current batch
        # elif training == False: the mean and variance returned are (self.mean, self.variance) or
        #   (self.moving_mean, self.moving_variance), depending on the value of use_moving_statistics

        if not self._bessels_correction_test_only:
            # Remove Bessel's correction to be consistent with non-fused batch norm.
            # Note that the variance computed by fused batch norm is
            # with Bessel's correction.
            sample_size = math_ops.cast(
                array_ops.size(inputs) / array_ops.size(variance),
                variance.dtype)
            factor = (sample_size -
                      math_ops.cast(1.0, variance.dtype)) / sample_size
            variance *= factor

        training_value = tf_utils.constant_value(training)

        if training_value is None:
            momentum = tf_utils.smart_cond(training, lambda: self.momentum,
                                           lambda: 1.0)
        else:
            momentum = ops.convert_to_tensor(self.momentum)

        if training_value or training_value is None:
            # if training, first create operations which update self.mean and self.variance
            mean_update = self._update_statistics(self.mean, mean,
                                                  self.n_updates)
            variance_update = self._update_statistics(self.variance, variance,
                                                      self.n_updates)

            with ops.control_dependencies([mean_update, variance_update]):
                update_n_updates = state_ops.assign_add(
                    self.n_updates,
                    1.,
                )

            # add this combination of operations to a specific collection 'UPDATE_BN_OPS'
            ops.add_to_collection('UPDATE_BN_OPS', update_n_updates)

            # operations to reset bn statistics
            reset_mean = state_ops.assign(self.mean,
                                          array_ops.zeros_like(self.mean))
            reset_variance = state_ops.assign(
                self.variance, array_ops.zeros_like(self.variance))
            reset_n_updates = state_ops.assign(self.n_updates, 0.)
            with ops.control_dependencies(
                [reset_mean, reset_variance, reset_n_updates]):
                reset_bn = gen_control_flow_ops.no_op("ResetBatchNormStats")
            ops.add_to_collection('RESET_BN_OPS', reset_bn)

            # Keep the standard batch norm behavior: update the moving averages and add the
            # update operations to tf.GraphKeys.UPDATE_OPS; these operations must be run when
            # optimizing the network.
            moving_mean_update = self._assign_moving_average(
                self.moving_mean, mean, momentum)
            moving_variance_update = self._assign_moving_average(
                self.moving_variance, variance, momentum)
            self.add_update(moving_mean_update, inputs=True)
            self.add_update(moving_variance_update, inputs=True)

        return output
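
The 'UPDATE_BN_OPS' and 'RESET_BN_OPS' collections populated above are custom; a hedged usage sketch (assuming graph mode and a placeholder training op, since the real optimizer op is not shown) of running them alongside the regular UPDATE_OPS:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # moving mean/variance updates
update_bn_ops = tf.get_collection('UPDATE_BN_OPS')       # raw mean/variance/n_updates updates
reset_bn_ops = tf.get_collection('RESET_BN_OPS')         # reset the raw statistics

train_step = tf.no_op(name='train_step')                 # stand-in for the real optimizer op

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run([train_step] + update_ops)  # ordinary training step
    sess.run(update_bn_ops)              # accumulate raw batch statistics
    sess.run(reset_bn_ops)               # reset before the next accumulation round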