def _get_queue_ops(self,
                   var_update_op: ops.Operation,
                   source_op: ops.Operation,
                   is_chief: bool,
                   is_trainable: bool) -> List[ops.Operation]:
    """
    Get queue operations for synchronous parameter update.

    One FIFO queue of capacity 1 is created per worker, named (and shared)
    as ``<variable op name>_update_sync_queue_<worker index>``. The chief
    pushes a token into the queue of every other worker once its update
    dependencies have run; every other worker dequeues a token from its own
    queue. The enqueue and dequeue operations are meant to be grouped by the
    caller and completed before the model moves on to the next step, which
    results in a synchronous parameter update.

    Args:
        var_update_op: The variable update op. The variable op is taken from
            its input at position ``UPDATE_OP_VAR_POS``.
        source_op: Additional op the chief waits on before enqueuing tokens
            when the variable is trainable. Not used for non-chief workers.
        is_chief: Whether the ops are built for the chief worker.
        is_trainable: Whether the variable is trainable. If so, the queue
            ops wait on the accumulator-apply op recorded for the variable
            in ``self._var_op_to_accum_apply_op``.

    Returns:
        A list of queue operations: one op per worker for the chief (a no-op
        at the chief's own index), or a single dequeue op otherwise.
    """
    var_op = var_update_op.inputs[UPDATE_OP_VAR_POS].op

    # One size-1 queue per worker; the graph name and the shared name match.
    sync_queues = []
    for worker_idx in range(self.num_workers):
        queue_name = '%s_update_sync_queue_%d' % (var_op.name, worker_idx)
        sync_queues.append(
            data_flow_ops.FIFOQueue(1, [dtypes.bool], shapes=[[]],
                                    name=queue_name,
                                    shared_name=queue_name))

    if not is_chief:
        # Wait for execution of var_update_op: block on this worker's token.
        if is_trainable:
            accum_apply_op = self._var_op_to_accum_apply_op[var_op]
            with ops.control_dependencies([accum_apply_op]):
                token_dequeue = sync_queues[self.worker_id].dequeue()
        else:
            token_dequeue = sync_queues[self.worker_id].dequeue()
        return [token_dequeue]

    # Chief enqueues tokens to all other workers after executing the
    # variable update.
    if is_trainable:
        chief_deps = [self._var_op_to_accum_apply_op[var_op], source_op]
    else:
        chief_deps = [var_update_op]

    token = constant_op.constant(False)
    with ops.control_dependencies(chief_deps):
        chief_ops = [
            gen_control_flow_ops.no_op() if idx == self.worker_id
            else queue.enqueue(token)
            for idx, queue in enumerate(sync_queues)
        ]
    return chief_ops
def test_parse_name_scope():
    """Check that utils.parse_name_scope recovers a prepended name scope."""
    with ops.Graph().as_default():
        # Case 1: a tensor name prefixed with a nested scope.
        nested_scope = 'name_scope/child_name_scope'
        const_a = constant_op.constant(5)
        scoped_tensor_name = ops.prepend_name_scope(const_a.name,
                                                    nested_scope)
        assert scoped_tensor_name == 'name_scope/child_name_scope/Const:0'
        assert utils.parse_name_scope(scoped_tensor_name) == nested_scope
        # An unscoped name yields the empty scope.
        assert utils.parse_name_scope(const_a.name) == ''

        # Case 2: a control-dependency input of a node, once prefixed with
        # the scope, carries a leading '^'.
        with ops.control_dependencies([no_op(name='my_op')]):
            const_b = constant_op.constant(6)
        flat_scope = 'name_scope'
        control_input = const_b.op.node_def.input[0]
        scoped_control_name = ops.prepend_name_scope(control_input,
                                                     flat_scope)
        assert scoped_control_name == '^name_scope/my_op'
        assert utils.parse_name_scope(scoped_control_name) == flat_scope
def update_loss_scale(self, finite_grads):
    """Return a no-op; `finite_grads` is accepted but ignored."""
    del finite_grads  # Unused.
    noop = gen_control_flow_ops.no_op()
    return noop
def close(self, name=None):
    """See TensorArray.

    This implementation only returns a no-op created with the given `name`.
    """
    close_op = gen_control_flow_ops.no_op(name=name)
    return close_op
def _no_op():
    """Return the result of `gen_control_flow_ops.no_op()`."""
    noop = gen_control_flow_ops.no_op()
    return noop
def call(self, inputs, training=None, use_moving_statistics=True):
    """Apply batch normalization with moving or raw statistics.

    :param inputs: input features
    :param training: boolean or boolean Tensor (with shape []) which
        determines the current training phase
    :param use_moving_statistics: boolean or boolean Tensor (with shape [])
        which selects the statistics to use when not training

    When training==True (or the Tensor value is True), the statistics
    (mean and variance) are computed from the inputs.

    When training==False:
        if use_moving_statistics==True, feed forward with the moving
        statistics (`self.moving_mean` / `self.moving_variance`, whose
        update operations are registered through `self.add_update`);
        else (use_moving_statistics==False), feed forward with the raw
        statistics (`self.mean` / `self.variance`, updated by the
        operations in the collection 'UPDATE_BN_OPS').

    'RESET_BN_OPS' contains operations to reset these variables between
    inferences. For backward compatibility the non-fused path also
    registers its reset operation under the legacy key 'RESET_OPS'.

    :return: the batch-normalized outputs
    """
    in_eager_mode = context.executing_eagerly()
    if self.virtual_batch_size is not None:
        # Virtual batches (aka ghost batches) can be simulated by reshaping
        # the Tensor and reusing the existing batch norm implementation
        original_shape = [-1] + inputs.shape.as_list()[1:]
        expanded_shape = [self.virtual_batch_size, -1] + original_shape[1:]

        # Will cause errors if virtual_batch_size does not divide the batch
        # size
        inputs = array_ops.reshape(inputs, expanded_shape)

        def undo_virtual_batching(outputs):
            outputs = array_ops.reshape(outputs, original_shape)
            return outputs

    if self.fused:
        outputs = self._fused_batch_norm(
            inputs,
            training=training,
            use_moving_statistics=use_moving_statistics)
        if self.virtual_batch_size is not None:
            # Currently never reaches here since fused_batch_norm does not
            # support virtual batching
            outputs = undo_virtual_batching(outputs)
        return outputs

    # Compute the axes along which to reduce the mean / variance
    input_shape = inputs.get_shape()
    ndims = len(input_shape)
    reduction_axes = [i for i in range(ndims) if i not in self.axis]
    if self.virtual_batch_size is not None:
        del reduction_axes[1]  # Do not reduce along virtual batch dim

    # Broadcasting only necessary for single-axis batch norm where the axis
    # is not the last dimension
    broadcast_shape = [1] * ndims
    broadcast_shape[self.axis[0]] = input_shape[self.axis[0]].value

    def _broadcast(v):
        if (v is not None and len(v.get_shape()) != ndims
                and reduction_axes != list(range(ndims - 1))):
            return array_ops.reshape(v, broadcast_shape)
        return v

    scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

    def _compose_transforms(scale, offset, then_scale, then_offset):
        if then_scale is not None:
            scale *= then_scale
            offset *= then_scale
        if then_offset is not None:
            offset += then_offset
        return (scale, offset)

    # Determine a boolean value for `training`: could be True, False, or
    # None.
    training_value = tf_utils.constant_value(training)
    if training_value is not False:
        if self.adjustment:
            adj_scale, adj_bias = self.adjustment(array_ops.shape(inputs))
            # Adjust only during training.
            adj_scale = tf_utils.smart_cond(
                training,
                lambda: adj_scale,
                lambda: array_ops.ones_like(adj_scale))
            adj_bias = tf_utils.smart_cond(
                training,
                lambda: adj_bias,
                lambda: array_ops.zeros_like(adj_bias))
            scale, offset = _compose_transforms(adj_scale, adj_bias,
                                                scale, offset)

        # Some of the computations here are not necessary when
        # training==False but not a constant. However, this makes the code
        # simpler.
        keep_dims = self.virtual_batch_size is not None or len(
            self.axis) > 1
        # mean and variance of the current batch
        mean, variance = nn.moments(inputs, reduction_axes,
                                    keep_dims=keep_dims)

        # When not training, fall back to the moving or the raw statistics
        # depending on `use_moving_statistics`.
        mean = tf_utils.smart_cond(
            training,
            lambda: mean,
            lambda: tf_utils.smart_cond(use_moving_statistics,
                                        lambda: self.moving_mean,
                                        lambda: self.mean))
        variance = tf_utils.smart_cond(
            training,
            lambda: variance,
            lambda: tf_utils.smart_cond(use_moving_statistics,
                                        lambda: self.moving_variance,
                                        lambda: self.variance))

        if self.renorm:
            r, d, new_mean, new_variance = \
                self._renorm_correction_and_moments(mean, variance,
                                                    training)
            # When training, the normalized values (say, x) will be
            # transformed as x * gamma + beta without renorm, and
            # (x * r + d) * gamma + beta
            # = x * (r * gamma) + (d * gamma + beta) with renorm.
            r = _broadcast(array_ops.stop_gradient(r, name='renorm_r'))
            d = _broadcast(array_ops.stop_gradient(d, name='renorm_d'))
            scale, offset = _compose_transforms(r, d, scale, offset)
        else:
            new_mean, new_variance = mean, variance

        if self.virtual_batch_size is not None:
            # This isn't strictly correct since in ghost batch norm, you are
            # supposed to sequentially update the moving_mean and
            # moving_variance with each sub-batch. However, since the moving
            # statistics are only used during evaluation, it is more
            # efficient to just update in one step and should not make a
            # significant difference in the result.
            new_mean = math_ops.reduce_mean(mean, axis=1, keepdims=True)
            new_variance = math_ops.reduce_mean(variance, axis=1,
                                                keepdims=True)

        def _do_update(var, value):
            if in_eager_mode and not self.trainable:
                return

            return self._assign_moving_average(var, value, self.momentum)

        moving_mean_update = tf_utils.smart_cond(
            training,
            lambda: _do_update(self.moving_mean, new_mean),
            lambda: self.moving_mean)
        moving_variance_update = tf_utils.smart_cond(
            training,
            lambda: _do_update(self.moving_variance, new_variance),
            lambda: self.moving_variance)
        if not context.executing_eagerly():
            self.add_update(moving_mean_update, inputs=True)
            self.add_update(moving_variance_update, inputs=True)

        # Operations which update the raw statistics self.mean and
        # self.variance.
        mean_update = self._update_statistics(self.mean, mean,
                                              self.n_updates)
        variance_update = self._update_statistics(self.variance, variance,
                                                  self.n_updates)
        with ops.control_dependencies([mean_update, variance_update]):
            # update n_updates only after updating self.mean and
            # self.variance
            update_n_updates = state_ops.assign_add(self.n_updates, 1.)
        ops.add_to_collection('UPDATE_BN_OPS', update_n_updates)

        # Operations to reset the raw statistics.
        reset_mean = state_ops.assign(self.mean,
                                      array_ops.zeros_like(self.mean))
        reset_variance = state_ops.assign(
            self.variance, array_ops.zeros_like(self.variance))
        reset_n_updates = state_ops.assign(self.n_updates, 0.)
        with ops.control_dependencies(
                [reset_mean, reset_variance, reset_n_updates]):
            reset_bn = gen_control_flow_ops.no_op("ResetBatchNormStats")
        # 'RESET_BN_OPS' is the documented collection and the one used by
        # the fused path; this path used to register only under
        # 'RESET_OPS', so that key is kept for existing callers.
        ops.add_to_collection('RESET_BN_OPS', reset_bn)
        ops.add_to_collection('RESET_OPS', reset_bn)
    else:
        # training == False
        mean = tf_utils.smart_cond(use_moving_statistics,
                                   lambda: self.moving_mean,
                                   lambda: self.mean)
        variance = tf_utils.smart_cond(use_moving_statistics,
                                       lambda: self.moving_variance,
                                       lambda: self.variance)

    mean = math_ops.cast(mean, inputs.dtype)
    variance = math_ops.cast(variance, inputs.dtype)
    if offset is not None:
        offset = math_ops.cast(offset, inputs.dtype)
    outputs = nn.batch_normalization(inputs,
                                     _broadcast(mean),
                                     _broadcast(variance),
                                     offset,
                                     scale,
                                     self.epsilon)
    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    if self.virtual_batch_size is not None:
        outputs = undo_virtual_batching(outputs)

    return outputs
def _fused_batch_norm(self, inputs, training, use_moving_statistics):
    """Returns the output of fused batch norm.

    When `training` is (or may be) true, this also creates the operations
    that update the raw statistics (collection 'UPDATE_BN_OPS'), the
    operations that reset them (collection 'RESET_BN_OPS'), and the moving
    average updates registered through `self.add_update`.

    :param inputs: input features
    :param training: boolean or boolean Tensor selecting the training or
        the inference kernel
    :param use_moving_statistics: boolean or boolean Tensor; for inference,
        selects (self.moving_mean, self.moving_variance) when true and
        (self.mean, self.variance) otherwise
    :return: the normalized output
    """
    if self.center:
        beta = self.beta
    else:
        beta = self._beta_const
    if self.scale:
        gamma = self.gamma
    else:
        gamma = self._gamma_const

    def _train_branch():
        return nn.fused_batch_norm(inputs,
                                   gamma,
                                   beta,
                                   epsilon=self.epsilon,
                                   data_format=self._data_format)

    # Statistics fed to the inference kernel: moving ones when
    # use_moving_statistics is true, raw ones otherwise. They are captured
    # by the closure below.
    inference_mean = tf_utils.smart_cond(use_moving_statistics,
                                         lambda: self.moving_mean,
                                         lambda: self.mean)
    inference_variance = tf_utils.smart_cond(use_moving_statistics,
                                             lambda: self.moving_variance,
                                             lambda: self.variance)

    def _inference_branch():
        return nn.fused_batch_norm(inputs,
                                   gamma,
                                   beta,
                                   mean=inference_mean,
                                   variance=inference_variance,
                                   epsilon=self.epsilon,
                                   is_training=False,
                                   data_format=self._data_format)

    # If training: mean and variance are those of the current batch.
    # Otherwise they are the statistics selected above.
    output, mean, variance = tf_utils.smart_cond(training,
                                                 _train_branch,
                                                 _inference_branch)

    if not self._bessels_correction_test_only:
        # Remove Bessel's correction to be consistent with non-fused batch
        # norm. Note that the variance computed by fused batch norm is
        # with Bessel's correction.
        n_samples = math_ops.cast(
            array_ops.size(inputs) / array_ops.size(variance),
            variance.dtype)
        one = math_ops.cast(1.0, variance.dtype)
        correction = (n_samples - one) / n_samples
        variance *= correction

    training_value = tf_utils.constant_value(training)
    if training_value is not None:
        momentum = ops.convert_to_tensor(self.momentum)
    else:
        # `training` is only known at run time: a momentum of 1.0 is used
        # when it turns out to be false.
        momentum = tf_utils.smart_cond(training,
                                       lambda: self.momentum,
                                       lambda: 1.0)

    if training_value is not None and not training_value:
        # Statically not training: no update operations are created.
        return output

    # First create the operations which update self.mean and self.variance;
    # n_updates is incremented only after both have run.
    raw_mean_update = self._update_statistics(self.mean, mean,
                                              self.n_updates)
    raw_variance_update = self._update_statistics(self.variance, variance,
                                                  self.n_updates)
    with ops.control_dependencies([raw_mean_update, raw_variance_update]):
        update_n_updates = state_ops.assign_add(self.n_updates, 1.)
    # Add this combination of operations to the collection 'UPDATE_BN_OPS'.
    ops.add_to_collection('UPDATE_BN_OPS', update_n_updates)

    # Operations to reset the bn statistics, grouped behind a single no-op.
    reset_deps = [
        state_ops.assign(self.mean, array_ops.zeros_like(self.mean)),
        state_ops.assign(self.variance,
                         array_ops.zeros_like(self.variance)),
        state_ops.assign(self.n_updates, 0.),
    ]
    with ops.control_dependencies(reset_deps):
        reset_bn = gen_control_flow_ops.no_op("ResetBatchNormStats")
    ops.add_to_collection('RESET_BN_OPS', reset_bn)

    # To keep the classical behavior of the Batch Norm: update the moving
    # averages and register the operations through add_update. These
    # operations must be run when optimizing the network.
    moving_mean_update = self._assign_moving_average(
        self.moving_mean, mean, momentum)
    moving_variance_update = self._assign_moving_average(
        self.moving_variance, variance, momentum)
    self.add_update(moving_mean_update, inputs=True)
    self.add_update(moving_variance_update, inputs=True)

    return output