def insert_bias_add_op(sess: tf.compat.v1.Session, conv_op_out_tensor: tf.Tensor,
                       new_bias_tensor: tf.Variable, bias_name="bias_value") -> None:
    """
    Insert a bias-add op after the given conv op.
    :param sess: model as tf.compat.v1.Session
    :param conv_op_out_tensor: output of the conv op that should feed into the new bias op, as tf.Tensor
    :param new_bias_tensor: bias tensor to be added, as tf.Variable
    :param bias_name: name string for the bias op
    :return: None
    Note: the higher-level API needs to perform a save and load to get an updated session after using this API.
    """
    assert conv_op_out_tensor is not None, 'Error, insert_bias_add_op() : conv op output tensor must be provided'

    with sess.graph.as_default():
        if conv_op_out_tensor.consumers():
            consumer_list = []
            for consumer in conv_op_out_tensor.consumers():
                consumer_list.append(consumer)

            # create the new bias-add op
            bias_add_op = tf.nn.bias_add(value=conv_op_out_tensor, bias=new_bias_tensor, name=bias_name)

            # use reroute to insert the bias-add and swap the current outputs of conv with the bias-add op
            ge.reroute_ts(bias_add_op, conv_op_out_tensor, can_modify=consumer_list)

            # initialize the bias variable once it has been added
            sess.run(tf.compat.v1.variables_initializer([new_bias_tensor]))

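# The snippets in this collection all build on the same tf.contrib.graph_editor primitive.
# Below is a minimal, self-contained sketch (not taken from any of the surrounding sources;
# the tensor names are illustrative) of ge.reroute_ts(ts0, ts1, can_modify=...): the
# consumers of ts1 listed in can_modify are rewired to read from ts0 instead, and the call
# returns the number of modifications it made. Assumes TF 1.x, where
# tensorflow.contrib.graph_editor is available.
import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 4], name='x')
    y = tf.identity(x, name='y')             # original producer
    out = tf.nn.relu(y, name='out')          # consumer currently reading y

    y_scaled = tf.multiply(y, 2.0, name='y_scaled')   # new producer to splice in

    # Rewire only 'out' to read y_scaled instead of y. Leaving y_scaled.op out of
    # can_modify keeps its own input on y, so no cycle is created.
    num_modified = ge.reroute_ts([y_scaled], [y], can_modify=[out.op])
    assert num_modified == 1
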
def convert_ckpt_to_pb(path, checkpoint, pb_path):
    """Convert ckpt files to a *.pb model.

    Args:
    * path: the folder of ckpt files.
    * checkpoint: the *.ckpt file.
    * pb_path: file path to the *.pb model
    """
    graph = tf.Graph()
    with graph.as_default():
        meta_path = os.path.join(path, checkpoint) + '.meta'
        saver = tf.train.import_meta_graph(meta_path)
        inputs = tf.get_collection('images')[0]
        image = tf.placeholder(dtype=tf.float32, shape=[None, HEIGHT, WIDTH, 3], name='images')
        graph_editor.reroute_ts(image, inputs)
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.Session(graph=graph, config=config) as sess:
            checkpoint_path = os.path.join(path, checkpoint)
            saver.restore(sess, checkpoint_path)
            input_graph_def = graph.as_graph_def()
            output_graph_def = graph_util.convert_variables_to_constants(
                sess, input_graph_def, ['output'])
            with tf.gfile.GFile(pb_path, 'wb') as f:
                f.write(output_graph_def.SerializeToString())
            print('successfully convert to pb')

def skip_bn_op(sess: tf.compat.v1.Session, bn_op: tf.Operation, in_tensor: tf.Tensor, out_tensor: tf.Tensor):
    """
    Skip the given batch norm op (fused batch norm op) by rerouting its input to its output.
    Note: supports only fused BN op types.
    :param sess: TensorFlow session
    :param bn_op: Batchnorm op to be skipped
    :param in_tensor: Input tensor to the batchnorm op
    :param out_tensor: Output tensor of the batchnorm op
    """
    if in_tensor is None or out_tensor is None:
        logger.error("Error, input and output tensors must be provided for skipping the op")
        assert False
    else:
        with sess.graph.as_default():
            if bn_op.type in ['FusedBatchNormV3', 'FusedBatchNorm']:
                ge.detach_outputs(in_tensor.op)
                ge.reroute_ts(in_tensor, out_tensor)
                BNUtils.remove_bn_op_from_update_ops(sess, bn_op)
            else:
                logger.error("Error, Unknown BN op")
                assert False

def test_reroute(self):
    ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
    self.assertTrue(match.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op))
    self.assertTrue(match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))

    ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
    self.assertTrue(match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
    self.assertTrue(match.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op))

def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
  """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
  input_to_ops_map = input_to_ops.InputToOps(graph)

  for bn in common.BatchNormGroups(graph):
    has_scaling = _HasScaling(graph, input_to_ops_map, bn)

    if not _IsValidUnfusedBatchNorm(graph, bn):
      continue

    # The mangling code intimately depends on BatchNorm node's internals.
    original_op, folded_op = _CreateFoldedOp(
        graph,
        bn,
        has_scaling=has_scaling,
        freeze_batch_norm_delay=freeze_batch_norm_delay,
        is_training=is_training)

    activation = common.GetEndpointActivationOp(graph, bn)
    if activation:
      nodes_modified_count = graph_editor.reroute_ts(
          [folded_op.outputs[0]], [original_op.outputs[0]],
          can_modify=[activation])
      if nodes_modified_count != 1:
        raise ValueError('Unexpected inputs to op: %s' % activation.name)
      continue

    # Treat consumer ops in bypass modules differently since they have Add
    # operations instead of Relu* above.
    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
    nodes_modified_count = graph_editor.reroute_ts(
        [folded_op.outputs[0]], [original_op.outputs[0]],
        can_modify=[add_bypass])
    if nodes_modified_count != 1:
      raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)

def _create_trainable_graph(
        adjustable_network: AdjustableThresholdsModel,
        reference_network: RegularModel,
        learning_strategy: LearningStrategy,
) -> Tuple[tf.Graph, tf.Tensor, TrainOperations]:
    adj_graph, [adj_input], [adj_output] = adjustable_network.graph_info
    ref_graph, [ref_input], [ref_output] = reference_network.graph_info

    trainable_graph = tf.Graph()
    _copy_graph(adj_graph, trainable_graph)
    _copy_graph(ref_graph, trainable_graph, 'reference_graph')

    adj_output = _get_transformed_tensor(adj_output, trainable_graph)
    ref_output = _get_transformed_tensor(ref_output, trainable_graph, 'reference_graph')
    adj_input = _get_transformed_tensor(adj_input, trainable_graph)
    ref_input = _get_transformed_tensor(ref_input, trainable_graph, 'reference_graph')

    with trainable_graph.as_default():  # pylint: disable=not-context-manager
        loss = _rmse_loss(ref_output, adj_output)

        learning_rate = tf.placeholder_with_default(
            learning_strategy.initial_lr,
            shape=[],
            name='lr_for_range_scalers',
        )

        optimizer_type = learning_strategy.optimizer_type
        optimizer_type = optimizer_type.lower()
        if optimizer_type == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)
        else:
            raise NotImplementedError(
                f'optimizer "{optimizer_type}" is not supported')

        gradients = optimizer.compute_gradients(loss, var_list=tf.trainable_variables())
        gradients = optimizer.apply_gradients(gradients)

        train_operations = TrainOperations(gradients, learning_rate, loss)
        train_input = adj_input

        ge.reroute_ts([train_input], [ref_input])

        return trainable_graph, train_input, train_operations

def _reroute_if_necessary(self, op: Op, new_op_tensors: List[tf.Tensor]) -> bool:
    """
    Reroute old op and new op outputs if the old op's children's masks are unchanged.
    If needed, insert downsample or upsample ops.
    :param op: Original unwinnowed op whose winnowed counterpart has output tensors new_op_tensors
    :param new_op_tensors: Output tensors of the newly created winnowed version of op
    :return: True if reroute was performed.
    """
    if len(new_op_tensors) > 1:
        # Len of new_op_tensors should only be greater than one in the case of ending at a split
        raise NotImplementedError

    current_op = op
    child_op = op.output.consumers[0]
    while child_op.type == 'branch' or \
            OpConnectivity.get_op_connectivity(ModelApi.tensorflow, child_op.type) == ConnectivityType.skip:
        # For both cases when child op is of type branch or skip connectivity, go one more level down
        current_op = child_op
        child_op = child_op.output.consumers[0]

    # Op output may have multiple consumers, but in the case of non splits, there will be only one tensor shared
    # among all consumers. Thus looking at only the first consumer is enough to give us the correct tensor to swap.
    if child_op in self._op_to_mask_dict.keys():
        op_mask = self._op_to_mask_dict[op]
        child_op_mask = self._op_to_mask_dict[child_op]
        if not child_op_mask.are_masks_unchanged():
            return False

        # find the correct child op input mask if it has multiple input masks
        prod_index = child_op.get_input_product_index_of_parent(current_op)
        assert prod_index is not None    # did not find any input product that connects to current op
        new_op_tensor = _insert_downsample_or_upsample_ops_if_needed(
            new_op_tensors[0],
            op_mask.output_channel_masks[0],
            child_op_mask.input_channel_masks[prod_index])
    else:
        prod_index = child_op.get_input_product_index_of_parent(current_op)
        assert prod_index is not None    # did not find any input product that connects to current op
        new_op_tensor = new_op_tensors[0]

    # We have hit the end of a string of ops to reduce, and will now connect the newly reduced ops back to the
    # main graph. This also detaches the old op's output from its old child op.
    old_tensor = child_op.get_input_products()[prod_index].tensor_dict[child_op]
    graph_editor.reroute_ts(ts0=new_op_tensor, ts1=old_tensor)
    return True

def _insert_weight_quantization_ops(self, ops, indices): if (not ops) or (len(ops) != len(indices)): raise ValueError('No weights to quantize!') self._is_train_variable = tf.Variable(initial_value=False, name='training_in_progress', dtype=tf.bool) self._sess.run(self._is_train_variable.initializer) for op, index in zip(ops, indices): # Modify the weight/bias inputs to use the quantized inputs param_in = op.inputs[index] self._log.debug('Quantizing input: %s for op: %s', param_in.name, op.name) # Rename using scope to be clearer what the op is quantizing. If no scope exists, use the default name w_op_name = os.path.split(param_in.name)[0] if not w_op_name: w_op_name = op.name w_op_name = self._get_quantized_name(w_op_name) self._log.info("Adding weight quantization op %s", w_op_name) # CPU device assignment for QcQuantize op if not self._gpu: with tf.device('/cpu:0'): q_op_out = _qcops.qc_quantize_deprecated( name=w_op_name, op_name=w_op_name, training_in_progress=self._is_train_variable, config=int( libpytrext.config_type.CONFIG_TYPE_Q_DQ_PARAMS), bitwidth=self._bw_params, in_tensors=[param_in], fixed_enc_mins=[], fixed_enc_maxs=[], quant_mode=self._quant_mode_str, round_mode=self._round_mode_str, num_tensors=1) # GPU device assignment for QcQuantize op else: q_op_out = _qcops.qc_quantize_deprecated( name=w_op_name, op_name=w_op_name, training_in_progress=self._is_train_variable, config=int(libpytrext.config_type.CONFIG_TYPE_Q_DQ_PARAMS), bitwidth=self._bw_params, in_tensors=[param_in], fixed_enc_mins=[], fixed_enc_maxs=[], quant_mode=self._quant_mode_str, round_mode=self._round_mode_str, num_tensors=1) nodes_modified_count = ge.reroute_ts(tf_ops.convert_to_tensor( q_op_out[0][0]), param_in, can_modify=op) if nodes_modified_count != 1: raise ValueError('Input ' + param_in.name + ' not quantized!')
def _FoldFusedBatchNorms(graph):
  """Finds fused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
  for match in _FindFusedBatchNorms(graph):
    scope, sep, _ = match.layer_op.name.rpartition('/')
    # Make sure new ops are added to `graph` and put on the same device as
    # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
    # named `scope`. Otherwise, TF creates a unique scope whose name starts with
    # `scope`.
    with graph.as_default(), graph.name_scope(scope + sep), ops.device(
        match.bn_op.device):
      with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep):
        # new weights = old weights * gamma / sqrt(variance + epsilon)
        # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
        multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
            match.variance_tensor + match.bn_op.get_attr('epsilon'))
        bias_tensor = math_ops.subtract(
            match.beta_tensor,
            match.mean_tensor * multiplier_tensor,
            name='bias')

        # The shape of depthwise weights is different, so we need to reshape the
        # multiplier_tensor to ensure that the scaled_weight_tensor has the
        # expected shape.
        if match.layer_op.type == 'DepthwiseConv2dNative':
          new_shape = [
              match.weight_tensor.get_shape().as_list()[2],
              match.weight_tensor.get_shape().as_list()[3]
          ]
          multiplier_tensor = array_ops.reshape(
              multiplier_tensor, new_shape, name='scale_reshape')

        # TODO(suharshs): This naming of the following ops needs to carefully
        # follow the naming expected by quantize.py. Generalize the quantize code
        # to not require these delicate naming conventions.
        scaled_weight_tensor = math_ops.multiply(
            match.weight_tensor, multiplier_tensor, name='mul_fold')

        new_layer_tensor = _CloneWithNewOperands(
            match.layer_op, match.input_tensor, scaled_weight_tensor)

        bias_add_tensor = math_ops.add(
            new_layer_tensor, bias_tensor, name='add_fold')

        nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
                                                       match.output_tensor)
        if nodes_modified_count != 1:
          raise ValueError(
              'Unexpected inputs to op: %s' % match.output_tensor.name)

def FoldBatchNorms(graph):
  """Finds batch norm layers in the graph, folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
  # Fail immediately when the graph contains unsupported fused batch norm ops.
  if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'):
    raise ValueError('Fused batch norm is not supported')

  input_to_ops_map = input_to_ops.InputToOps(graph)
  for bn in common.BatchNormGroups(graph):
    has_scaling = _HasScaling(graph, input_to_ops_map, bn)

    # The mangling code intimately depends on BatchNorm node's internals.
    original_op, folded_op = _CreateFoldedOp(graph, bn, has_scaling=has_scaling)

    activation = common.GetEndpointActivationOp(graph, bn)
    if activation:
      nodes_modified_count = graph_editor.reroute_ts(
          [folded_op.outputs[0]], [original_op.outputs[0]],
          can_modify=[activation])
      if nodes_modified_count != 1:
        raise ValueError('Unexpected inputs to op: %s' % activation.name)
      continue

    # Treat consumer ops in bypass modules differently since they have Add
    # operations instead of Relu* above.
    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
    nodes_modified_count = graph_editor.reroute_ts(
        [folded_op.outputs[0]], [original_op.outputs[0]],
        can_modify=[add_bypass])
    if nodes_modified_count != 1:
      raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)

def replace_relu6_with_relu(sess: tf.compat.v1.Session, relu6_op: tf.Operation):
    """
    Replaces existing Relu6 op with a Relu.
    :param sess: active tf.compat.v1.Session
    :param relu6_op: Relu6 op to be replaced with Relu
    :return:
    """
    with sess.graph.as_default():
        assert len(relu6_op.inputs) == 1
        new_tensor = tf.nn.relu(relu6_op.inputs[0])  # pylint: disable=no-member
        relu_op = new_tensor.op

        relu_outputs = list(relu_op.outputs)
        relu6_outputs = list(relu6_op.outputs)

        # swap the two tensors using reroute
        ge.reroute_ts(ts0=relu_outputs, ts1=relu6_outputs)
        ge.detach_inputs(relu6_op)

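# A brief usage sketch for the helper above (hedged: the wrapper below and its
# save-and-reload requirement are assumptions, not code from the original source).
# As the first snippet in this collection notes, graph-editor changes only take
# effect for a session after saving and re-loading the model.
def replace_all_relu6_with_relu(sess: tf.compat.v1.Session) -> int:
    """Replace every Relu6 op in the session's graph; returns how many were replaced."""
    relu6_ops = [op for op in sess.graph.get_operations() if op.type == 'Relu6']
    for op in relu6_ops:
        replace_relu6_with_relu(sess, op)
    return len(relu6_ops)
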
def __build_pruned_evaluate_model(self, path=None):
  '''build an evaluation model from the pruned model'''
  # early break for non-primary workers
  if not self.__is_primary_worker():
    return

  if path is None:
    path = FLAGS.save_path
  if not tf.train.checkpoint_exists(path):
    return

  with tf.Graph().as_default():
    config = tf.ConfigProto()
    config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
      mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
    self.sess_eval = tf.Session(config=config)
    self.saver_eval = tf.train.import_meta_graph(path + '.meta')
    self.saver_eval.restore(self.sess_eval, path)
    eval_logits = tf.get_collection('logits')[0]
    tf.add_to_collection('logits_final', eval_logits)
    eval_images = tf.get_collection('eval_images')[0]
    tf.add_to_collection('images_final', eval_images)
    eval_labels = tf.get_collection('eval_labels')[0]
    mem_images = tf.get_collection('mem_images')[0]
    mem_labels = tf.get_collection('mem_labels')[0]

    self.sess_eval.close()

    graph_editor.reroute_ts(eval_images, mem_images)
    graph_editor.reroute_ts(eval_labels, mem_labels)

    self.sess_eval = tf.Session(config=config)
    self.saver_eval.restore(self.sess_eval, path)
    trainable_vars = self.trainable_vars
    loss, accuracy = self.calc_loss(eval_labels, eval_logits, trainable_vars)

    self.eval_op = [loss] + list(accuracy.values())
    self.sm_writer.add_graph(self.sess_eval.graph)

def replace_layer_with_sequential_of_two_layers(self, layer_to_replace: Layer,
                                                layer_a: Layer, layer_b: Layer):
    """
    Replaces the original layer with two new layers in the graph.
    Adds the two new layers to the database and removes the original layer from the database.

    :param layer_to_replace: layer to replace
    :param layer_a: layer a
    :param layer_b: layer b
    """
    old_bias_op = aimet_tensorflow.utils.common.get_succeeding_bias_op(layer_to_replace.module)
    old_outputs = [old_bias_op.outputs[0]] if old_bias_op is not None else [layer_to_replace.module.outputs[0]]

    new_bias_op = aimet_tensorflow.utils.common.get_succeeding_bias_op(layer_b.module)
    new_outputs = [new_bias_op.outputs[0]] if new_bias_op is not None else [layer_b.module.outputs[0]]

    consumers = []
    for output in old_outputs:
        for consumer in output.consumers():
            consumers.append(consumer)

    # For each tensor pair, replaces the end of [ts1 = old_outputs] with the end of [ts0 = new_outputs].
    # The ends of the tensors in [ts1 = old_outputs] are left dangling.
    _ = graph_editor.reroute_ts(ts0=new_outputs, ts1=old_outputs, can_modify=consumers)

    # Add the new layers to the database
    self._compressible_layers[id(layer_a.module)] = layer_a
    self._compressible_layers[id(layer_b.module)] = layer_b

    # Remove the layer being replaced from the database
    del self._compressible_layers[id(layer_to_replace.module)]

def test_compatibility(self):
    with self.assertRaises(ValueError):
        ge.reroute_ts([self.a0, self.b0], [self.a2, self.b2])

def gradients(ys, xs, # pylint: disable: too-many-statements, too-many-branches grad_ys=None, checkpoints='collection', **kwargs): ''' Authors: Tim Salimans & Yaroslav Bulatov memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174) ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients) 'checkpoints' can either be - a list consisting of tensors from the forward pass of the neural net that we should re-use when calculating the gradients in the backward pass all other tensors that do not appear in this list will be re-computed - a string specifying how this list should be determined. currently we support - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive, so checkpointing them maximizes the running speed (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory) - 'memory': try to minimize the memory usage (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint) - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint ''' # print("Calling memsaving gradients with", checkpoints) if not isinstance(ys, list): ys = [ys] if not isinstance(xs, list): xs = [xs] bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True) debug_print("bwd_ops: {}".format(bwd_ops)) # forward ops are all ops that are candidates for recomputation fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], inclusive=True, within_ops=bwd_ops) debug_print("fwd_ops: {}".format(fwd_ops)) # exclude ops with no inputs fwd_ops = [op for op in fwd_ops if op.inputs] # don't recompute xs, remove variables xs_ops = _to_ops(xs) fwd_ops = [op for op in fwd_ops if op not in xs_ops] fwd_ops = [op for op in fwd_ops if '/assign' not in op.name] fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name] fwd_ops = [op for op in fwd_ops if '/read' not in op.name] ts_all = ge.filter_ts(fwd_ops, True) # get the tensors ts_all = [t for t in ts_all if '/read' not in t.name] ts_all = set(ts_all) - set(xs) - set(ys) # construct list of tensors to checkpoint during forward pass, if not # given as input if type(checkpoints) is not list: if checkpoints == 'collection': checkpoints = tf.get_collection('checkpoints') elif checkpoints == 'speed': # checkpoint all expensive ops to maximize running speed checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul') elif checkpoints == 'memory': # remove very small tensors and some weird ops def fixdims(t): # tf.Dimension values are not compatible with int, convert manually try: return [int(e if e.value is not None else 64) for e in t] except: return [0] # unknown shape ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE] ts_all = [t for t in ts_all if 'L2Loss' not in t.name] ts_all = [t for t in ts_all if 'entropy' not in t.name] ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name] ts_all = [t for t in ts_all if 'Switch' not in t.name] ts_all = [t for t in ts_all if 'dropout' not in t.name] # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16 ts_all = [t for t in ts_all if 'Cast' not in t.name] # filter out all tensors that are inputs of the backward graph with util.capture_ops() as bwd_ops: tf_gradients(ys, xs, grad_ys, **kwargs) 
bwd_inputs = [t for op in bwd_ops for t in op.inputs] # list of tensors in forward graph that is in input to bwd graph ts_filtered = list(set(bwd_inputs).intersection(ts_all)) debug_print("Using tensors {}".format(ts_filtered)) # try two slightly different ways of getting bottlenecks tensors # to checkpoint for ts in [ts_filtered, ts_all]: # get all bottlenecks in the graph bottleneck_ts = [] for t in ts: b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops)) f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops)) # check that there are not shortcuts b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all) f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all) if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all): bottleneck_ts.append(t) # we have a bottleneck! else: debug_print("Rejected bottleneck candidate and ops {}".format( [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))) # success? or try again without filtering? if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # enough bottlenecks found! break if not bottleneck_ts: raise Exception('unable to find bottleneck tensors! please provide checkpoint ' 'nodes manually, or use checkpoints="speed".') # sort the bottlenecks bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops) sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts] # save an approximately optimal number ~ sqrt(N) N = len(ts_filtered) if len(bottleneck_ts) <= np.ceil(np.sqrt(N)): checkpoints = sorted_bottlenecks else: step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N))) checkpoints = sorted_bottlenecks[step::step] else: raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,)) checkpoints = list(set(checkpoints).intersection(ts_all)) # at this point automatic selection happened and checkpoints is list of nodes assert isinstance(checkpoints, list) debug_print("Checkpoint nodes used: {}".format(checkpoints)) # better error handling of special cases # xs are already handled as checkpoint nodes, so no need to include them xs_intersect_checkpoints = set(xs).intersection(set(checkpoints)) if xs_intersect_checkpoints: debug_print("Warning, some input nodes are also checkpoint nodes: {}".format( xs_intersect_checkpoints)) ys_intersect_checkpoints = set(ys).intersection(set(checkpoints)) debug_print("ys: {}, checkpoints:{}, intersect: {}".format( ys, checkpoints, ys_intersect_checkpoints)) # saving an output node (ys) gives no benefit in memory while creating # new edge cases, exclude them if ys_intersect_checkpoints: debug_print("Warning, some output nodes are also checkpoints nodes: {}".format( format_ops(ys_intersect_checkpoints))) # remove initial and terminal nodes from checkpoints list if present checkpoints = list(set(checkpoints) - set(ys) - set(xs)) # check that we have some nodes to checkpoint if not checkpoints: raise Exception('no checkpoints nodes found or given as input! 
') # disconnect dependencies between checkpointed tensors checkpoints_disconnected = {} for x in checkpoints: if x.op and x.op.name is not None: grad_node = tf.stop_gradient(x, name=x.op.name+"_sg") else: grad_node = tf.stop_gradient(x) checkpoints_disconnected[x] = grad_node # partial derivatives to the checkpointed tensors and xs ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], stop_at_ts=checkpoints, within_ops=fwd_ops) debug_print("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format( len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)) debug_print("ops_to_copy = {}".format(ops_to_copy)) debug_print("Processing list {}".format(ys)) _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) for origin_op, op in info._transformed_ops.items(): op._set_device(origin_op.node_def.device) copied_ops = info._transformed_ops.values() debug_print("Copied {} to {}".format(ops_to_copy, copied_ops)) ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops) debug_print("Rewired {} in place of {} restricted to {}".format( checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops)) # get gradients with respect to current boundary + original x's copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys] boundary = list(checkpoints_disconnected.values()) dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs) debug_print("Got gradients {}".format(dv)) debug_print("for %s", copied_ys) debug_print("with respect to {}".format(boundary+xs)) inputs_to_do_before = [y.op for y in ys] if grad_ys is not None: inputs_to_do_before += grad_ys wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) # partial derivatives to the checkpointed nodes # dictionary of "node: backprop" for nodes in the boundary d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(), dv[:len(checkpoints_disconnected)])} # partial derivatives to xs (usually the params of the neural net) d_xs = dv[len(checkpoints_disconnected):] # incorporate derivatives flowing through the checkpointed nodes checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops) for ts in checkpoints_sorted_lists[::-1]: debug_print("Processing list {}".format(ts)) checkpoints_other = [r for r in checkpoints if r not in ts] checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other] # copy part of the graph below current checkpoint node, stopping at # other checkpoints nodes ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other) debug_print("Found {} ops to copy within {}, seed {}, stop_at {}".format( len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other)) debug_print("ops_to_copy = {}".format(ops_to_copy)) if not ops_to_copy: # we're done! 
break _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) for origin_op, op in info._transformed_ops.items(): op._set_device(origin_op.node_def.device) copied_ops = info._transformed_ops.values() debug_print("Copied {} to {}".format(ops_to_copy, copied_ops)) ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops) debug_print("Rewired %s in place of %s restricted to %s", checkpoints_disconnected_other, checkpoints_other, copied_ops) # gradient flowing through the checkpointed node boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts] substitute_backprops = [d_checkpoints[r] for r in ts] dv = tf_gradients(boundary, checkpoints_disconnected_other+xs, grad_ys=substitute_backprops, **kwargs) debug_print("Got gradients {}".format(dv)) debug_print("for {}".format(boundary)) debug_print("with respect to {}".format(checkpoints_disconnected_other+xs)) debug_print("with boundary backprop substitutions {}".format(substitute_backprops)) inputs_to_do_before = [d_checkpoints[r].op for r in ts] wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) # partial derivatives to the checkpointed nodes for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]): if dr is not None: if d_checkpoints[r] is None: d_checkpoints[r] = dr else: d_checkpoints[r] += dr def _unsparsify(var_x): if not isinstance(var_x, tf.IndexedSlices): return var_x assert var_x.dense_shape is not None, \ "memory_saving_gradients encountered sparse gradients of unknown shape" indices = var_x.indices while indices.shape.ndims < var_x.values.shape.ndims: indices = tf.expand_dims(indices, -1) return tf.scatter_nd(indices, var_x.values, var_x.dense_shape) # partial derivatives to xs (usually the params of the neural net) d_xs_new = dv[len(checkpoints_other):] for j in range(len(xs)): if d_xs_new[j] is not None: if d_xs[j] is None: d_xs[j] = _unsparsify(d_xs_new[j]) else: d_xs[j] += _unsparsify(d_xs_new[j]) return d_xs
def _InsertQuantOp( self, context, producer, consumers, name, moving_avg=True, init_min=-6.0, init_max=6.0, delay_requested=True, bits=8, narrow_range=False, ): """Inserts a quant op between a producer op and (multiple) consumer ops. Args: context: Context where producer and consumer operations are nested. producer: Producer operation of the pairs where quantization will be inserted. consumers: Consumer operations of the pairs. name: Name for the new quantization op within the context. moving_avg: Specifies whether to use exponential moving average or just the last value seen. init_min: Starting minimum value for the new quantization op. init_max: Starting maximum value for the new quantization op. delay_requested: If true, implement quantization delay where needed. False value explicitly disables delay quantization everywhere. bits: Number of bits to use for quantization, must be between 2 and 8. narrow_range: Whether to use the narrow quantization range [1; 2^bits - 1] or wide range [0; 2^bits - 1]. Raises: ValueError: When producer operation is not directly connected to the consumer operation. """ scope = context + '/' + name inputs = producer.outputs[0] if moving_avg: quant = (quant_ops.MovingAvgQuantize( inputs, init_min=init_min, init_max=init_max, ema_decay=self.ema_decay, is_training=self.is_training, num_bits=bits, narrow_range=narrow_range, updates_collection=_UPDATE_QUANT_OPS, vars_collection=self.vars_collection, scope=scope)) else: quant = (quant_ops.LastValueQuantize( inputs, init_min=init_min, init_max=init_max, is_training=self.is_training, num_bits=bits, narrow_range=narrow_range, updates_collection=_UPDATE_QUANT_OPS, vars_collection=self.vars_collection, scope=scope)) if delay_requested and self.quant_delay and self.quant_delay > 0: activate_quant = math_ops.greater_equal( training_util.get_or_create_global_step(), self.quant_delay, name=scope + '/activate_quant') quant = control_flow_ops.cond(activate_quant, lambda: quant, lambda: inputs, name=scope + '/delayed_quant') nodes_modified_count = graph_editor.reroute_ts([quant], [inputs], can_modify=consumers) if nodes_modified_count != len(consumers): raise ValueError( 'Some inputs not quantized for ops: [%s]' % ', '.join([consumer.name for consumer in consumers]))
def _split_conv_layer(self, sess, svd_ranks, attr, op_name, bias_op_name=None): """ Split a given conv layer given a rank :param sess: tf.compat.v1.Session :param svd_ranks: Rank to split the layer with (two ranks in case of SSVD) :param attr: Reference to the corresponding layer attribute :param op_name: Name of the op to split :param bias_op_name: Name of the corresponding bias op (if any) :return: None """ # pylint: disable=too-many-statements,too-many-branches,too-many-locals logger.info('Splitting conv op: %s', op_name) # Retrieve the op(s) from the current graph op = sess.graph.get_operation_by_name(op_name) bias_op = None if bias_op_name: bias_op = sess.graph.get_operation_by_name(bias_op_name) # Create new 'conv_a' layer pad_mode = op.get_attr('padding') data_format = op.get_attr('data_format').decode('utf-8') strides = op.get_attr('strides') # Print current conv weight shape query = core.OpQuery(sess.graph) w_shape = query.get_weights_for_op(op).get_shape().as_list() logger.debug('Original %s weight shape: %s', op.name, str(w_shape)) split_weights, weight_sizes = [], [] split_biases, bias_sizes = [], [] # TF weights are in [H,W,I,O] order. We must reshape the split weights to SVD format [O,I,H,W] # and then transpose back # Conv a weights are: [1, 1, w_shape[2], svd_ranks[0]] split_conv_a_w_shape = (svd_ranks[0], w_shape[2], 1, 1) conv_a_weights = np.zeros(split_conv_a_w_shape) # transpose(2,3,1,0) split_weights.append(conv_a_weights.flatten().tolist()) weight_sizes.append(conv_a_weights.size) if bias_op: conv_a_bias = np.zeros(svd_ranks[0]) split_biases.append(conv_a_bias.flatten().tolist()) bias_sizes.append(conv_a_bias.size) num_filters = w_shape[3] if len(svd_ranks) >= 2 and attr.mode == pymo.TYPE_SUCCESSIVE: # Output channels = output_rank (s) num_filters = svd_ranks[1] # Conv b weights are: [w_shape[0],w_shape[1],svd_ranks[0],num_filters] split_conv_b_w_shape = (num_filters, svd_ranks[0], w_shape[0], w_shape[1]) conv_b_weights = np.zeros(split_conv_b_w_shape) conv_b_bias = np.zeros(num_filters) split_weights.append(conv_b_weights.flatten().tolist()) weight_sizes.append(conv_b_weights.size) if bias_op: split_biases.append(conv_b_bias.flatten().tolist()) bias_sizes.append(conv_b_bias.size) # Only create a third conv layer when performing successive SVD if len(svd_ranks) >= 2 and attr.mode == pymo.TYPE_SUCCESSIVE: # Conv c weights are: [1,1,num_filters,w_shape[3]] split_conv_c_w_shape = (w_shape[3], num_filters, 1, 1) conv_c_weights = np.zeros(split_conv_c_w_shape) conv_c_bias = np.zeros(w_shape[3]) split_weights.append(conv_c_weights.flatten().tolist()) weight_sizes.append(conv_c_weights.size) if bias_op: split_biases.append(conv_c_bias.flatten().tolist()) bias_sizes.append(conv_c_bias.size) # Split the weights and biases according to the number of layers and ranks split_weights = self._svd.SplitLayerWeights(op.name, split_weights, weight_sizes, svd_ranks) split_biases = self._svd.SplitLayerBiases(op.name, split_biases, bias_sizes, svd_ranks) if split_weights: conv_a_name = op.name+'_a' conv_a_weights = np.array(split_weights[0]).reshape(split_conv_a_w_shape).transpose(2, 3, 1, 0) conv_a_w = tf.Variable(initial_value=conv_a_weights, name=conv_a_name+'_w', dtype=tf.float32) logger.debug('%s weight shape: %s', conv_a_name, str(conv_a_weights.shape)) # Create conv_a using default strides (1,1) # pylint: disable=no-member conv_acts = tf.nn.conv2d(op.inputs[0], conv_a_w, strides=[1, 1, 1, 1], data_format=data_format, padding=pad_mode, name=op.name+'_a') # 
dilation_rate=dilation_rate if bias_op: conv_a_bias = tf.Variable(initial_value=split_biases[0], name=conv_a_name+'_bias', dtype=tf.float32) conv_acts = conv_acts + conv_a_bias # tf.nn.bias_add(conv_acts, split_biases[0]) if len(split_weights) > 1: # Create conv_b conv_b_name = op.name+'_b' conv_b_weights = np.array(split_weights[1]).reshape(split_conv_b_w_shape).transpose(2, 3, 1, 0) conv_b_w = tf.Variable(initial_value=conv_b_weights, name=conv_b_name+'_w', dtype=tf.float32) logger.debug('%s weight shape: %s', conv_b_name, str(conv_b_weights.shape)) # pylint: disable=no-member conv_acts = tf.nn.conv2d(conv_acts, conv_b_w, strides=strides, data_format=data_format, padding=pad_mode, name=conv_b_name) #dilation_rate=dilation_rate if bias_op: conv_b_bias = tf.Variable(initial_value=split_biases[1], name=conv_b_name+'_bias', dtype=tf.float32) conv_acts = conv_acts + conv_b_bias # tf.nn.bias_add(conv_acts, split_biases[1]) ratio = self._compute_per_layer_compression_ratio([conv_a_w.shape, conv_b_w.shape], conv_acts.shape, w_shape, "Conv2D") # Only create a third conv layer when performing successive SVD if len(split_weights) > 2 and len(svd_ranks) >= 2 and attr.mode == pymo.TYPE_SUCCESSIVE: # Create conv_c, using default strides (1,1) conv_c_name = op.name+'_c' conv_c_weights = np.array(split_weights[2]).reshape(split_conv_c_w_shape).transpose(2, 3, 1, 0) conv_c_w = tf.Variable(initial_value=conv_c_weights, name=conv_c_name+'_w', dtype=tf.float32) logger.debug('%s weight shape: %s', conv_c_name, str(conv_c_weights.shape)) # pylint: disable=no-member conv_acts = tf.nn.conv2d(conv_acts, conv_c_w, strides=[1, 1, 1, 1], data_format=data_format, padding=pad_mode, name=conv_c_name) if bias_op: conv_c_bias = tf.Variable(initial_value=split_biases[2], name=conv_c_name+'_bias', dtype=tf.float32) conv_acts = conv_acts + conv_c_bias # tf.nn.bias_add(conv_acts, split_biases[2]) consumers = [] rerouted_inputs = [bias_op.outputs[0]] if bias_op else [op.outputs[0]] for inp in rerouted_inputs: for consumer in inp.consumers(): consumers.append(consumer) _ = ge.reroute_ts(conv_acts, rerouted_inputs, can_modify=consumers) return ratio
def replace_input(op, old_input, new_input):
    """Replaces old input with new input in op"""
    ge.reroute_ts([new_input], [old_input], can_modify=[op])

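# Hedged usage sketch for replace_input (the tiny graph below is illustrative and not
# from the original source): the single listed consumer op is redirected from one
# tensor to another, while every other consumer of old_input is left untouched.
# Assumes TF 1.x with tensorflow.contrib.graph_editor imported as ge, as above.
import tensorflow as tf

with tf.Graph().as_default():
    a = tf.constant(1.0, name='a')
    b = tf.constant(2.0, name='b')
    c = tf.add(a, 1.0, name='c')     # c currently computes a + 1.0

    replace_input(c.op, a, b)        # c now computes b + 1.0

    with tf.Session() as sess:
        assert sess.run(c) == 3.0
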
def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): """Finds fused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise convolution. Args: graph: Graph to walk and modify. is_training: Bool, true if training. freeze_batch_norm_delay: How many steps to wait before freezing moving mean and variance and using them for batch normalization. Raises: ValueError: When batch norm folding fails. """ for match in _FindFusedBatchNorms(graph): scope, sep, _ = match.layer_op.name.rpartition('/') # Make sure new ops are added to `graph` and put on the same device as # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope # named `scope`. Otherwise, TF creates a unique scope whose name starts with # `scope`. with graph.as_default(), graph.name_scope(scope + sep): with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep): # new weights = old weights * gamma / sqrt(variance + epsilon) # new biases = -mean * gamma / sqrt(variance + epsilon) + beta multiplier_tensor = match.gamma_tensor * math_ops.rsqrt( match.variance_tensor + match.bn_op.get_attr('epsilon')) bias_tensor = math_ops.subtract(match.beta_tensor, match.mean_tensor * multiplier_tensor, name='bias') correction_scale, correction_recip, correction_offset = None, None, None if is_training: correction_scale, correction_recip, correction_offset = ( _ComputeBatchNormCorrections( context='', match=match, freeze_batch_norm_delay=freeze_batch_norm_delay, fused_batch_norm=True)) # The shape of depthwise weights is different, so we need to reshape the # multiplier_tensor to ensure that the scaled_weight_tensor has the # expected shape. weights = match.weight_tensor if match.layer_op.type == 'DepthwiseConv2dNative': new_shape = [ match.weight_tensor.get_shape().as_list()[2], match.weight_tensor.get_shape().as_list()[3] ] multiplier_tensor = array_ops.reshape(multiplier_tensor, new_shape, name='scale_reshape') if correction_scale is not None: correction_scale = array_ops.reshape( correction_scale, new_shape, name='correction_reshape') if correction_scale is not None: weights = math_ops.multiply(correction_scale, weights, name='correction_mult') scaled_weight_tensor = math_ops.multiply(weights, multiplier_tensor, name='mul_fold') new_layer_tensor = _CloneWithNewOperands(match.layer_op, match.input_tensor, scaled_weight_tensor, match.batch_to_space_op) if correction_recip is not None: new_layer_tensor = math_ops.multiply(correction_recip, new_layer_tensor, name='post_conv_mul') new_layer_tensor = math_ops.add(new_layer_tensor, (correction_offset), 'correction_add') bias_add_tensor = math_ops.add(new_layer_tensor, bias_tensor, name='add_fold') nodes_modified_count = graph_editor.reroute_ts( bias_add_tensor, match.output_tensor) if nodes_modified_count == 0: raise ValueError( 'Folding batch norms failed, %s had no outputs.' % match.output_tensor.name)
def replace_varible(self, new_list):
    assert len(new_list) == len(self._var)
    for new_var, modify_ops in zip(new_list, self._modify_consumer):
        for op, v_replaced in modify_ops:
            ge.reroute_ts(new_var, v_replaced, can_modify=[op])

def __build_pruned_train_model(self, path=None, finetune=False): # pylint: disable=too-many-locals ''' build a training model from pruned model ''' if path is None: path = FLAGS.save_path with tf.Graph().as_default(): config = tf.ConfigProto() config.gpu_options.visible_device_list = str(# pylint: disable=no-member mgw.local_rank() if FLAGS.enbl_multi_gpu else 0) self.sess_train = tf.Session(config=config) self.saver_train = tf.train.import_meta_graph(path + '.meta') self.saver_train.restore(self.sess_train, path) logits = tf.get_collection('logits')[0] train_images = tf.get_collection('train_images')[0] train_labels = tf.get_collection('train_labels')[0] mem_images = tf.get_collection('mem_images')[0] mem_labels = tf.get_collection('mem_labels')[0] self.sess_train.close() graph_editor.reroute_ts(train_images, mem_images) graph_editor.reroute_ts(train_labels, mem_labels) self.sess_train = tf.Session(config=config) self.saver_train.restore(self.sess_train, path) trainable_vars = self.trainable_vars loss, accuracy = self.calc_loss(train_labels, logits, trainable_vars) self.accuracy_keys = list(accuracy.keys()) if FLAGS.enbl_dst: logits_dst = self.learner_dst.calc_logits(self.sess_train, train_images) loss += self.learner_dst.calc_loss(logits, logits_dst) tf.summary.scalar('loss', loss) for key in accuracy.keys(): tf.summary.scalar(key, accuracy[key]) self.summary_op = tf.summary.merge_all() global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, trainable=False) self.global_step = global_step lrn_rate, self.nb_iters_train = setup_lrn_rate( self.global_step, self.model_name, self.dataset_name) if finetune and not FLAGS.cp_retrain: mom_optimizer = tf.train.AdamOptimizer(FLAGS.cp_lrn_rate_ft) self.log_op = [tf.constant(FLAGS.cp_lrn_rate_ft), loss, list(accuracy.values())] else: mom_optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum) self.log_op = [lrn_rate, loss, list(accuracy.values())] if FLAGS.enbl_multi_gpu: optimizer = mgw.DistributedOptimizer(mom_optimizer) else: optimizer = mom_optimizer grads_origin = optimizer.compute_gradients(loss, trainable_vars) grads_pruned, masks = self.__calc_grads_pruned(grads_origin) with tf.control_dependencies(self.update_ops): self.train_op = optimizer.apply_gradients(grads_pruned, global_step=global_step) self.sm_writer.add_graph(tf.get_default_graph()) self.train_init_op = \ tf.initialize_variables(mom_optimizer.variables() + [global_step] + masks) if FLAGS.enbl_multi_gpu: self.bcast_op = mgw.broadcast_global_variables(0)
def _InsertQuantOp(context, name, producer, consumers, is_training, moving_avg=True, init_min=-6.0, init_max=6.0, bits=8, ema_decay=0.999, quant_delay=None, vars_collection=ops.GraphKeys.GLOBAL_VARIABLES, narrow_range=False, producer_scope=None, consumer_scope=None): """Inserts a quant op between a producer op and (multiple) consumer ops. Args: context: Context where producer and consumer operations are nested. name: Name for the new quantization op within the context. producer: Producer operation of the pairs where quantization will be inserted. consumers: Consumer operations of the pairs. is_training: Whether quantizing training graph or eval graph. moving_avg: Specifies whether to use exponential moving average or just the last value seen. init_min: Starting minimum value for the new quantization op. init_max: Starting maximum value for the new quantization op. bits: Number of bits to use for quantization, must be between 2 and 8. ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update quantization intervals for quantizing activations (see here about EMA: https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average). quant_delay: (Optional, default None) Int, count of global steps for which to delay quantization. This helps weights stabilize at the start of training. vars_collection: (Optional) Collection where to store the variables for quantization interval ends. narrow_range: Whether to use the narrow quantization range [1; 2^bits - 1] or wide range [0; 2^bits - 1]. producer_scope: The restriction of producer scope. If not None, the new op will be inserted only when the producer is in this scope. consumer_scope: The restriction of producer scope. If not None, the new op will be inserted only when all the consumers are in this scope. Raises: ValueError: When producer operation is not directly connected to the consumer operation. """ if producer_scope and not producer.name.startswith(producer_scope): logging.info( '_InsertQuantOp ignores context="%s" name="%s" ' 'because producer "%s" is not in scope "%s"', context, name, producer.name, producer_scope) return if consumer_scope: consumers_in_scope = [] for consumer in consumers: if consumer.name.startswith(consumer_scope): consumers_in_scope.append(consumer) else: logging.info( '_InsertQuantOp context="%s" name="%s" ignores ' 'consumer "%s" because it is not in scope "%s"', context, name, consumer.name, consumer_scope) return consumers = consumers_in_scope name_prefix = _AddContextToName(context, name) # This is needed on TPU where name_scope == 'TPUReplicate/loop', and # name_prefix starts with 'TPUReplicate/loop/'; without dropping it # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which # breaks things later. name_scope = ops.get_name_scope() if name_scope: name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/') inputs = producer.outputs[0] # Prevent ops from being quantized multiple times. Bypass ops can sometimes # overlap between multiple matches, so we need to ensure that we don't # add duplicate FakeQuant operations. 
fake_quant_ops = set([ 'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs' ]) if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])): return if moving_avg: quant = ( quant_ops.MovingAvgQuantize( inputs, init_min=init_min, init_max=init_max, ema_decay=ema_decay, is_training=is_training, num_bits=bits, narrow_range=narrow_range, vars_collection=vars_collection, name_prefix=name_prefix)) else: quant = ( quant_ops.LastValueQuantize( inputs, init_min=init_min, init_max=init_max, is_training=is_training, num_bits=bits, narrow_range=narrow_range, vars_collection=vars_collection, name_prefix=name_prefix)) if quant_delay and quant_delay > 0: activate_quant = math_ops.greater_equal( common.CreateOrGetQuantizationStep(), quant_delay, name=name_prefix + '/activate_quant') quant = control_flow_ops.cond( activate_quant, lambda: quant, lambda: inputs, name=name_prefix + '/delayed_quant') if consumers: tensors_modified_count = graph_editor.reroute_ts( [quant], [inputs], can_modify=consumers) # Some operations can have multiple output tensors going to the same # consumer. Since consumers is a set, we need to ensure that # tensors_modified_count is greater than or equal to the length of the set # of consumers. if tensors_modified_count < len(consumers): raise ValueError('No inputs quantized for ops: [%s]' % ', '.join( [consumer.name for consumer in consumers]))
X_train = X_train_
Y_train = Y_train_

# Reroute tensors to the location of the data on the GPU, add noise and image augmentation (random crops)
train_idx = tf.placeholder(tf.int32, shape=[None], name='train_idx')
tf.add_to_collection('placeholders', train_idx)

input_noise_magnitude = tf.placeholder_with_default(
    0.0, shape=[], name='input_noise_magnitude')
tf.add_to_collection('placeholders', input_noise_magnitude)

X_train = tf.gather(X_train, train_idx)
X_train = tf.pad(X_train, paddings=[[0, 0], [3, 3], [3, 3], [0, 0]])
X_train = u.random_crop(X_train, [28, 28, 1])
X_train += input_noise_magnitude * tf.random_normal(tf.shape(X_train), dtype=tf.float32)
Y_train = tf.gather(Y_train, train_idx)

ge.reroute_ts([X_train, Y_train], [X, labels])

# Define tensors to compute accuracy and confidence metrics
Y_pred = tf.argmax(Y, axis=-1, output_type=tf.int32)
prob = tf.nn.softmax(Y, dim=-1)
Y_pred_prob = tf.reduce_max(prob, axis=-1)
A = tf.cast(tf.equal(Y_train, Y_pred), tf.float32)
acc = tf.reduce_mean(A)
conf = tf.reduce_mean(Y_pred_prob * A)

# Start the TF session and load variables
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION

with tf.Session(config=config) as sess:
    if RESET_PARAMETERS:

def _InsertQuantOp(context, name, producer, consumers, is_training, moving_avg=True, init_min=-6.0, init_max=6.0, bits=8, ema_decay=0.999, quant_delay=None, vars_collection=ops.GraphKeys.GLOBAL_VARIABLES, narrow_range=False): """Inserts a quant op between a producer op and (multiple) consumer ops. Args: context: Context where producer and consumer operations are nested. name: Name for the new quantization op within the context. producer: Producer operation of the pairs where quantization will be inserted. consumers: Consumer operations of the pairs. is_training: Whether quantizing training graph or eval graph. moving_avg: Specifies whether to use exponential moving average or just the last value seen. init_min: Starting minimum value for the new quantization op. init_max: Starting maximum value for the new quantization op. bits: Number of bits to use for quantization, must be between 2 and 8. ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update quantization intervals for quantizing activations (see here about EMA: https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average). quant_delay: (Optional, default None) Int, count of global steps for which to delay quantization. This helps weights stabilize at the start of training. vars_collection: (Optional) Collection where to store the variables for quantization interval ends. narrow_range: Whether to use the narrow quantization range [1; 2^bits - 1] or wide range [0; 2^bits - 1]. Raises: ValueError: When producer operation is not directly connected to the consumer operation. """ name_prefix = _AddContextToName(context, name) # This is needed on TPU where name_scope == 'TPUReplicate/loop', and # name_prefix starts with 'TPUReplicate/loop/'; without dropping it # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which # breaks things later. name_prefix = common.DropStringPrefix(name_prefix, ops.get_name_scope() + '/') inputs = producer.outputs[0] if moving_avg: quant = (quant_ops.MovingAvgQuantize(inputs, init_min=init_min, init_max=init_max, ema_decay=ema_decay, is_training=is_training, num_bits=bits, narrow_range=narrow_range, vars_collection=vars_collection, name_prefix=name_prefix)) else: quant = (quant_ops.LastValueQuantize(inputs, init_min=init_min, init_max=init_max, is_training=is_training, num_bits=bits, narrow_range=narrow_range, vars_collection=vars_collection, name_prefix=name_prefix)) if quant_delay and quant_delay > 0: activate_quant = math_ops.greater_equal( common.CreateOrGetQuantizationStep(), quant_delay, name=name_prefix + '/activate_quant') quant = control_flow_ops.cond(activate_quant, lambda: quant, lambda: inputs, name=name_prefix + '/delayed_quant') nodes_modified_count = graph_editor.reroute_ts([quant], [inputs], can_modify=consumers) if nodes_modified_count != len(consumers): raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join([consumer.name for consumer in consumers]))
def _InsertQuantOp(context, name, producer, consumers, is_training, moving_avg=True, init_min=-6.0, init_max=6.0, bits=8, ema_decay=0.999, quant_delay=None, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, narrow_range=False): """Inserts a quant op between a producer op and (multiple) consumer ops. Args: context: Context where producer and consumer operations are nested. name: Name for the new quantization op within the context. producer: Producer operation of the pairs where quantization will be inserted. consumers: Consumer operations of the pairs. is_training: Whether quantizing training graph or eval graph. moving_avg: Specifies whether to use exponential moving average or just the last value seen. init_min: Starting minimum value for the new quantization op. init_max: Starting maximum value for the new quantization op. bits: Number of bits to use for quantization, must be between 2 and 8. ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update quantization intervals for quantizing activations (see here about EMA: https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average). quant_delay: (Optional, default None) Int, count of global steps for which to delay quantization. This helps weights stabilize at the start of training. vars_collection: (Optional) Collection where to store the variables for quantization interval ends. narrow_range: Whether to use the narrow quantization range [1; 2^bits - 1] or wide range [0; 2^bits - 1]. Raises: ValueError: When producer operation is not directly connected to the consumer operation. """ name_prefix = _AddContextToName(context, name) inputs = producer.outputs[0] if moving_avg: quant = ( quant_ops.MovingAvgQuantize( inputs, init_min=init_min, init_max=init_max, ema_decay=ema_decay, is_training=is_training, num_bits=bits, narrow_range=narrow_range, vars_collection=vars_collection, name_prefix=name_prefix)) else: quant = ( quant_ops.LastValueQuantize( inputs, init_min=init_min, init_max=init_max, is_training=is_training, num_bits=bits, narrow_range=narrow_range, vars_collection=vars_collection, name_prefix=name_prefix)) if quant_delay and quant_delay > 0: activate_quant = math_ops.greater_equal( common.CreateOrGetQuantizationStep(), quant_delay, name=name_prefix + '/activate_quant') quant = control_flow_ops.cond( activate_quant, lambda: quant, lambda: inputs, name=name_prefix + '/delayed_quant') nodes_modified_count = graph_editor.reroute_ts( [quant], [inputs], can_modify=consumers) if nodes_modified_count != len(consumers): raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join( [consumer.name for consumer in consumers]))
def _insert_activation_quantization_ops(self, ops, collect_stats=True): # pylint: disable=too-many-locals encodings = self._activation_encodings['encodings'] if encodings: self._log.info("Using fixed activation encodings") # Add all the activation quantization operations for op in ops: op_name = self._get_quantized_name(op.name) self._log.info("Adding quantization activation %s for %s", op_name, op.name) # When fixed encodings are set we aren't collecting stats, so use the collected encodings enc_mins, enc_maxs = [], [] config = int(libpytrext.config_type.CONFIG_TYPE_UPDATE_STATS) if not collect_stats: if not encodings: raise RuntimeError( 'No activation encodings recorded for activation quantization ops!' ) if op_name not in encodings: raise RuntimeError("Can't find activation encoding for: " + op_name) config = int( libpytrext.config_type.CONFIG_TYPE_Q_DQ_ACTIVATIONS) encoding_tuple = encodings[op_name] enc_mins = encoding_tuple[0] enc_maxs = encoding_tuple[1] self._log.debug("Using min,max encodings: %s,%s", enc_mins, enc_maxs) # Add the new activation quantization op and reroute the outputs from the producer node to the # quantization op and the quantization outputs to the consumer(s) inputs = [output for output in op.outputs] num_tensors = len(inputs) consumers = [] for inp in inputs: for consumer in inp.consumers(): if 'gradients' not in consumer.name: consumers.append(consumer) # CPU device assignment for QcQuantize op if not self._gpu: with tf.device('/cpu:0'): q_op_out = _qcops.qc_quantize_deprecated( name=op_name, op_name=op_name, training_in_progress=self._is_train_variable, config=config, bitwidth=self._bw_acts, in_tensors=inputs, fixed_enc_mins=enc_mins, fixed_enc_maxs=enc_maxs, quant_mode=self._quant_mode_str, round_mode=self._round_mode_str, num_tensors=num_tensors) # GPU device assignment for QcQuantize op else: q_op_out = _qcops.qc_quantize_deprecated( name=op_name, op_name=op_name, training_in_progress=self._is_train_variable, config=config, bitwidth=self._bw_acts, in_tensors=inputs, fixed_enc_mins=enc_mins, fixed_enc_maxs=enc_maxs, quant_mode=self._quant_mode_str, round_mode=self._round_mode_str, num_tensors=num_tensors) qc_outputs = [ tf_ops.convert_to_tensor(q_op_out[i][0]) for i in range(len(inputs)) ] num_rerouted_outputs = ge.reroute_ts(qc_outputs, inputs, can_modify=consumers) if num_rerouted_outputs != len(consumers): raise ValueError('Failed to map ' + str(len(consumers)) + ' quantization output(s). Only mapped ' + str(num_rerouted_outputs)) # Save the activation quantization op name for later if collect_stats: self._quant_act_ops.append(op_name)
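A generic sketch of the same rerouting pattern using only public TF 1.x ops (a standard FakeQuant stands in for the custom qc_quantize_deprecated op; the helper name is hypothetical): every output tensor of an op is wrapped, and all non-gradient consumers are switched over to the wrapped tensors with ge.reroute_ts.

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

def wrap_op_outputs_with_fake_quant(op, bitwidth=8):
    """Reroute all non-gradient consumers of `op` through FakeQuant nodes."""
    inputs = list(op.outputs)
    consumers = [c for t in inputs for c in t.consumers()
                 if 'gradients' not in c.name]
    quantized = [
        tf.quantization.fake_quant_with_min_max_args(
            t, min=-6.0, max=6.0, num_bits=bitwidth)
        for t in inputs
    ]
    # One rerouted tensor per original output; returns the number of
    # consumer inputs that were actually rewired.
    return ge.reroute_ts(quantized, inputs, can_modify=consumers)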
def _InsertQuantOp( self, context, producer, consumers, name, moving_avg=True, init_min=-6.0, init_max=6.0, delay_requested=True, bits=8, narrow_range=False,): """Inserts a quant op between a producer op and (multiple) consumer ops. Args: context: Context where producer and consumer operations are nested. producer: Producer operation of the pairs where quantization will be inserted. consumers: Consumer operations of the pairs. name: Name for the new quantization op within the context. moving_avg: Specifies whether to use exponential moving average or just the last value seen. init_min: Starting minimum value for the new quantization op. init_max: Starting maximum value for the new quantization op. delay_requested: If true, implement quantization delay where needed. False value explicitly disables delay quantization everywhere. bits: Number of bits to use for quantization, must be between 2 and 8. narrow_range: Whether to use the narrow quantization range [1; 2^bits - 1] or wide range [0; 2^bits - 1]. Raises: ValueError: When producer operation is not directly connected to the consumer operation. """ scope = context + '/' + name inputs = producer.outputs[0] if moving_avg: quant = (quant_ops.MovingAvgQuantize( inputs, init_min=init_min, init_max=init_max, ema_decay=self.ema_decay, is_training=self.is_training, num_bits=bits, narrow_range=narrow_range, updates_collection=_UPDATE_QUANT_OPS, vars_collection=self.vars_collection, scope=scope)) else: quant = (quant_ops.LastValueQuantize( inputs, init_min=init_min, init_max=init_max, is_training=self.is_training, num_bits=bits, narrow_range=narrow_range, updates_collection=_UPDATE_QUANT_OPS, vars_collection=self.vars_collection, scope=scope)) if delay_requested and self.quant_delay and self.quant_delay > 0: activate_quant = math_ops.greater_equal( training_util.get_or_create_global_step(), self.quant_delay, name=scope + '/activate_quant') quant = control_flow_ops.cond( activate_quant, lambda: quant, lambda: inputs, name=scope + '/delayed_quant') nodes_modified_count = graph_editor.reroute_ts( [quant], [inputs], can_modify=consumers) if nodes_modified_count != len(consumers): raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join([consumer.name for consumer in consumers]))
def _split_fc_layer(self, sess, svd_ranks, op_name, bias_op_name=None): """ Split a given fully connected layer given a rank :param sess: tf.compat.v1.Session :param svd_ranks: Rank to split the layer with (two ranks in case of SSVD) :param op_name: Name of the op to split :param bias_op_name: Name of the corresponding bias op (if any) :return: None """ # pylint: disable=too-many-statements, too-many-locals logger.info('Splitting fully connected op: %s', op_name) # Retrieve the op(s) from the current graph op = sess.graph.get_operation_by_name(op_name) bias_op = None if bias_op_name: bias_op = sess.graph.get_operation_by_name(bias_op_name) # Print current FC weight shape query = core.OpQuery(sess.graph) w_shape = query.get_weights_for_op(op).get_shape().as_list() logger.debug('Original %s weight shape: %s', op.name, str(w_shape)) split_weights, weight_sizes = [], [] split_biases, bias_sizes = [], [] # FC a weights are: [w_shape[0],svd_ranks[0]] in [I,O] order. # We must reshape the split weights to SVD format [O,I] and then transpose back to [I,O] split_fc_a_w_shape = (svd_ranks[0], w_shape[0]) fc_a_weights = np.zeros(split_fc_a_w_shape) fc_a_bias = np.zeros(svd_ranks[0]) split_weights.append(fc_a_weights.flatten().tolist()) weight_sizes.append(fc_a_weights.size) if bias_op: split_biases.append(fc_a_bias.flatten().tolist()) bias_sizes.append(fc_a_bias.size) # FC b weights are: [svd_ranks[0],w_shape[1]] in [I,O] order. # We must reshape the split weights to SVD format [O,I] and then transpose back to [I,O] split_fc_b_w_shape = (w_shape[1], svd_ranks[0]) fc_b_weights = np.zeros(split_fc_b_w_shape) split_weights.append(fc_b_weights.flatten().tolist()) weight_sizes.append(fc_b_weights.size) if bias_op: fc_b_bias = np.zeros(w_shape[1]) split_biases.append(fc_b_bias.flatten().tolist()) bias_sizes.append(fc_b_bias.size) # Split the weights and biases according to the number of layers and ranks split_weights = self._svd.SplitLayerWeights(op.name, split_weights, weight_sizes, svd_ranks) split_biases = self._svd.SplitLayerBiases(op.name, split_biases, bias_sizes, svd_ranks) if split_weights: fc_a_name = op.name+'_a' fc_a_weights = np.array(split_weights[0]).reshape(split_fc_a_w_shape).transpose(1, 0) fc_a_w = tf.Variable(initial_value=fc_a_weights, name=fc_a_name+'_w', dtype=tf.float32) logger.debug('%s weight shape: %s', fc_a_name, str(fc_a_weights.shape)) # Create fc_a (matmul) fc_acts = tf.matmul(op.inputs[0], fc_a_w, name=fc_a_name) if bias_op: fc_a_bias = tf.Variable(initial_value=split_biases[0], name=fc_a_name+'_bias', dtype=tf.float32) fc_acts = fc_acts + fc_a_bias if len(split_weights) > 1: # Create fc_b fc_b_name = op.name+'_b' fc_b_weights = np.array(split_weights[1]).reshape(split_fc_b_w_shape).transpose(1, 0) fc_b_w = tf.Variable(initial_value=fc_b_weights, name=fc_b_name+'_w', dtype=tf.float32) logger.debug('%s weight shape: %s', fc_b_name, str(fc_b_weights.shape)) fc_acts = tf.matmul(fc_acts, fc_b_w, name=fc_b_name) if bias_op: fc_b_bias = tf.Variable(initial_value=split_biases[1], name=fc_b_name+'_bias', dtype=tf.float32) fc_acts = fc_acts + fc_b_bias ratio = self._compute_per_layer_compression_ratio([fc_a_w.shape, fc_b_w.shape], fc_acts.shape, w_shape, 'MatMul') consumers = [] rerouted_inputs = [bias_op.outputs[0]] if bias_op else [op.outputs[0]] for inp in rerouted_inputs: for consumer in inp.consumers(): consumers.append(consumer) _ = ge.reroute_ts(fc_acts, rerouted_inputs, can_modify=consumers) return ratio
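The numerical idea behind the split, as an illustrative NumPy sketch independent of the class above: a single fully connected weight matrix W of shape [I, O] is approximated by two factors A ([I, r]) and B ([r, O]) from a truncated SVD, so one large matmul becomes two cheaper ones.

import numpy as np

rng = np.random.RandomState(0)
in_dim, out_dim, rank = 256, 128, 32
W = rng.randn(in_dim, out_dim)

# Truncated SVD: W ~= A @ B with A: [in_dim, rank], B: [rank, out_dim]
U, s, Vt = np.linalg.svd(W, full_matrices=False)
A = U[:, :rank] * s[:rank]
B = Vt[:rank, :]

x = rng.randn(4, in_dim)
approx = (x @ A) @ B          # fc_a followed by fc_b
rel_err = np.linalg.norm(x @ W - approx) / np.linalg.norm(x @ W)
print('relative error: %.3f' % rel_err)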
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): ''' Authors: Tim Salimans & Yaroslav Bulatov memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174) ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients) 'checkpoints' can either be - a list consisting of tensors from the forward pass of the neural net that we should re-use when calculating the gradients in the backward pass all other tensors that do not appear in this list will be re-computed - a string specifying how this list should be determined. currently we support - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive, so checkpointing them maximizes the running speed (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory) - 'memory': try to minimize the memory usage (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint) - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint ''' # print("Calling memsaving gradients with", checkpoints) if not isinstance(ys, list): ys = [ys] if not isinstance(xs, list): xs = [xs] bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True) debug_print("bwd_ops: %s", bwd_ops) # forward ops are all ops that are candidates for recomputation fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], inclusive=True, within_ops=bwd_ops) debug_print("fwd_ops: %s", fwd_ops) # exclude ops with no inputs fwd_ops = [op for op in fwd_ops if op.inputs] # don't recompute xs, remove variables xs_ops = _to_ops(xs) fwd_ops = [op for op in fwd_ops if not op in xs_ops] fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/read' in op.name] ts_all = ge.filter_ts(fwd_ops, True) # get the tensors ts_all = [t for t in ts_all if '/read' not in t.name] ts_all = set(ts_all) - set(xs) - set(ys) checkpoints = 'collection' # construct list of tensors to checkpoint during forward pass, if not # given as input stereo_checkpoints = ge.filter_ts_from_regex(fwd_ops, "add") motion_checkpoints = ge.filter_ts_from_regex(fwd_ops, "Conv2D") my_ckps = [] for x in motion_checkpoints: if ("motion" in x.name) and ("BatchNorm" not in x.name): my_ckps.append(x) for x in stereo_checkpoints: if ("stereo" in x.name) and ("BatchNorm" not in x.name): my_ckps.append(x) checkpoints = my_ckps checkpoints = list(set(checkpoints).intersection(ts_all)) # at this point automatic selection happened and checkpoints is list of nodes assert isinstance(checkpoints, list) debug_print("Checkpoint nodes used: %s", checkpoints) # better error handling of special cases # xs are already handled as checkpoint nodes, so no need to include them xs_intersect_checkpoints = set(xs).intersection(set(checkpoints)) if xs_intersect_checkpoints: debug_print("Warning, some input nodes are also checkpoint nodes: %s", xs_intersect_checkpoints) ys_intersect_checkpoints = set(ys).intersection(set(checkpoints)) debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints, ys_intersect_checkpoints) # saving an output node (ys) gives no benefit in memory while creating # new edge cases, exclude them if ys_intersect_checkpoints: debug_print( 
"Warning, some output nodes are also checkpoints nodes: %s", format_ops(ys_intersect_checkpoints)) # remove initial and terminal nodes from checkpoints list if present checkpoints = list(set(checkpoints) - set(ys) - set(xs)) # check that we have some nodes to checkpoint if not checkpoints: raise Exception('no checkpoints nodes found or given as input! ') # disconnect dependencies between checkpointed tensors checkpoints_disconnected = {} for x in checkpoints: if x.op and x.op.name is not None: grad_node = tf.stop_gradient(x, name=x.op.name + "_sg") else: grad_node = tf.stop_gradient(x) checkpoints_disconnected[x] = grad_node # partial derivatives to the checkpointed tensors and xs ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], stop_at_ts=checkpoints, within_ops=fwd_ops) debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s", len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints) debug_print("ops_to_copy = %s", ops_to_copy) debug_print("Processing list %s", ys) copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) copied_ops = info._transformed_ops.values() debug_print("Copied %s to %s", ops_to_copy, copied_ops) ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops) debug_print("Rewired %s in place of %s restricted to %s", checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops) # get gradients with respect to current boundary + original x's copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys] boundary = list(checkpoints_disconnected.values()) dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys, **kwargs) debug_print("Got gradients %s", dv) debug_print("for %s", copied_ys) debug_print("with respect to %s", boundary + xs) inputs_to_do_before = [y.op for y in ys] if grad_ys is not None: inputs_to_do_before += grad_ys wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) # partial derivatives to the checkpointed nodes # dictionary of "node: backprop" for nodes in the boundary d_checkpoints = { r: dr for r, dr in zip(checkpoints_disconnected.keys(), dv[:len(checkpoints_disconnected)]) } # partial derivatives to xs (usually the params of the neural net) d_xs = dv[len(checkpoints_disconnected):] # incorporate derivatives flowing through the checkpointed nodes checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops) for ts in checkpoints_sorted_lists[::-1]: debug_print("Processing list %s", ts) checkpoints_other = [r for r in checkpoints if r not in ts] checkpoints_disconnected_other = [ checkpoints_disconnected[r] for r in checkpoints_other ] # copy part of the graph below current checkpoint node, stopping at # other checkpoints nodes ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other) debug_print("Found %s ops to copy within %s, seed %s, stop_at %s", len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other) debug_print("ops_to_copy = %s", ops_to_copy) if not ops_to_copy: # we're done! 
break copied_sgv, info = ge.copy_with_input_replacements( ge.sgv(ops_to_copy), {}) copied_ops = info._transformed_ops.values() debug_print("Copied %s to %s", ops_to_copy, copied_ops) ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops) debug_print("Rewired %s in place of %s restricted to %s", checkpoints_disconnected_other, checkpoints_other, copied_ops) # gradient flowing through the checkpointed node boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts] substitute_backprops = [d_checkpoints[r] for r in ts] dv = tf_gradients(boundary, checkpoints_disconnected_other + xs, grad_ys=substitute_backprops, **kwargs) debug_print("Got gradients %s", dv) debug_print("for %s", boundary) debug_print("with respect to %s", checkpoints_disconnected_other + xs) debug_print("with boundary backprop substitutions %s", substitute_backprops) inputs_to_do_before = [d_checkpoints[r].op for r in ts] wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) # partial derivatives to the checkpointed nodes for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]): if dr is not None: if d_checkpoints[r] is None: d_checkpoints[r] = dr else: d_checkpoints[r] += dr # partial derivatives to xs (usually the params of the neural net) d_xs_new = dv[len(checkpoints_other):] for j in range(len(xs)): if d_xs_new[j] is not None: if d_xs[j] is None: d_xs[j] = d_xs_new[j] else: d_xs[j] += d_xs_new[j] return d_xs
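The core graph-surgery trick used above, reduced to a small sketch (TF 1.x; layer sizes and names are made up): copy the ops between a checkpointed tensor and the loss, then reroute the copies to read from a stop_gradient'ed version of the checkpoint, so the backward pass recomputes that segment instead of keeping every intermediate activation alive.

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

x = tf.placeholder(tf.float32, [None, 4])
h = tf.nn.relu(tf.layers.dense(x, 8), name='ckpt')   # checkpointed tensor
y = tf.reduce_sum(tf.layers.dense(h, 1))

# Disconnect the checkpoint from its upstream graph for gradient purposes.
h_sg = tf.stop_gradient(h, name='ckpt_sg')

# Ops between the checkpoint and the loss; keep variables and their
# read ops shared with the original graph instead of copying them.
ops_to_copy = ge.get_backward_walk_ops([y.op], inclusive=True, stop_at_ts=[h])
ops_to_copy = [op for op in ops_to_copy
               if op.inputs and '/read' not in op.name]
_, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
copied_ops = list(info._transformed_ops.values())

# Make the copies read from the disconnected checkpoint; tf.gradients on the
# copied loss will then recompute this segment starting from h_sg.
ge.reroute_ts([h_sg], [h], can_modify=copied_ops)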
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, fused_batch_norm): """Computes batch norm correction params. Before batch normalization is frozen: We use batch statistics for batch norm. correction_scale = sigma_b/sigma_mv correction_recip = 1/correction_scale correction_offset = 0 After batch normalization is frozen: correction_scale = sigma_b/sigma_mv correction_recip = 1 correction_offset = gamma*(mu_b/sigma_b-mu_mv/sigma_mv). Batch norm is frozen if global_step > bn_freeze_delay. The corrections ensure that: a) The weights are quantized after scaling by gamma/sigma_mv. This enables smoother training as the scaling on the weights changes slowly, rather than jump across mini-batches b) Changing the values of the corrections allows for one to switch between using batch statistics to using moving mean and average, without requiring changes to batch_norm Args: context: The scope under which we look for batch norm params match: Object containing required batch norm tensors for correction computation. freeze_batch_norm_delay: Delay in steps at which computation switches from regular batch norm to frozen mean and variance. fused_batch_norm: Bool, true if fused batch norm is used. Returns: A tuple of correction_scale, correction_recip, correction_offset """ g = ops.get_default_graph() prefix = '' if not context else context + '/' with g.name_scope(prefix + 'batch_norm_correction'): recip_sigma_mv = math_ops.rsqrt( match.moving_variance_tensor + match.batch_epsilon) recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon) correction_scale = math_ops.divide( recip_sigma_mv, recip_sigma, name='scale_compute') correction_scale = array_ops.identity( correction_scale, name='correction_scale') correction_recip = math_ops.reciprocal( correction_scale, name='reciprocal_compute') correction_offset = math_ops.multiply( match.gamma_tensor, match.mean_tensor * recip_sigma - match.moving_mean_tensor * recip_sigma_mv, name='offset_compute') if freeze_batch_norm_delay is not None: use_mv_avg = math_ops.greater_equal( common.CreateOrGetQuantizationStep(), freeze_batch_norm_delay, name='use_moving_average') else: use_mv_avg = False bn_decay_zero = 0.0 bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers()) bn_decay_var_consumers = list(match.bn_decay_mean_tensor.consumers()) bn_decay_mean_out = utils.smart_cond( use_mv_avg, lambda: bn_decay_zero, lambda: match.bn_decay_mean_tensor, name='freeze_moving_mean') graph_editor.reroute_ts( [bn_decay_mean_out], [match.bn_decay_mean_tensor], can_modify=bn_decay_mean_consumers) if fused_batch_norm is False: bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) bn_decay_var_out = utils.smart_cond( use_mv_avg, lambda: bn_decay_zero, lambda: match.bn_decay_var_tensor, name='freeze_moving_var') graph_editor.reroute_ts( [bn_decay_var_out], [match.bn_decay_var_tensor], can_modify=bn_decay_var_consumers) correction_recip = utils.smart_cond( use_mv_avg, lambda: array_ops.ones(correction_scale.shape), lambda: correction_recip, name='correction_recip') correction_offset = utils.smart_cond( use_mv_avg, lambda: correction_offset, lambda: array_ops.zeros(correction_offset.shape), name='correction_offset') return correction_scale, correction_recip, correction_offset
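A compact sketch of the freeze mechanism above (TF 1.x; `bn_decay_tensor` stands in for the matched match.bn_decay_mean_tensor or bn_decay_var_tensor, assumed scalar): once the global step passes the delay, the decay factor feeding the moving-average update is swapped for a constant zero via reroute_ts, with the consumer list captured before the cond is built so the switch itself is not rerouted.

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge

def freeze_bn_decay(bn_decay_tensor, freeze_batch_norm_delay):
    """Replace `bn_decay_tensor` with 0 for its existing consumers after a delay."""
    consumers = list(bn_decay_tensor.consumers())  # capture before adding new ops
    step = tf.train.get_or_create_global_step()
    use_mv_avg = tf.greater_equal(step, freeze_batch_norm_delay,
                                  name='use_moving_average')
    frozen_decay = tf.cond(use_mv_avg,
                           lambda: tf.constant(0.0, dtype=bn_decay_tensor.dtype),
                           lambda: bn_decay_tensor,
                           name='freeze_moving_mean')
    return ge.reroute_ts([frozen_decay], [bn_decay_tensor], can_modify=consumers)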
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): print("-------------------------------") debug_print("Editing model for OME") incpu_count = 0 # print("Calling memsaving gradients with", checkpoints) if not isinstance(ys, list): ys = [ys] if not isinstance(xs, list): xs = [xs] bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True) for index, op in enumerate(bwd_ops): debug_print("bwd_ops: [{}] :{}".format(index, op.name), 1) # forward ops are all ops that are candidates for recomputation fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], inclusive=True, within_ops=bwd_ops) for index, op in enumerate(fwd_ops): debug_print("fwd_ops: [{}] : {}".format(index, op.name), 1) # exclude ops with no inputs fwd_ops = [op for op in fwd_ops if op.inputs] # don't recompute xs, remove variables xs_ops = _to_ops(xs) fwd_ops = [op for op in fwd_ops if not op in xs_ops] fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/read' in op.name] ts_all = ge.filter_ts(fwd_ops, True) # get the tensors ts_all = [t for t in ts_all if '/read' not in t.name] ts_all = set(ts_all) - set(xs) - set(ys) # construct list of tensors to checkpoint during forward pass, if not # given as input if type(checkpoints) is not list: # remove very small tensors and some weird ops def fixdims( t ): # tf.Dimension values are not compatible with int, convert manually try: return [int(e if e.value is not None else 64) for e in t] except: return [0] # unknown shape ts_all = [ t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE ] ts_all = [t for t in ts_all if 'L2Loss' not in t.name] ts_all = [t for t in ts_all if 'entropy' not in t.name] ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name] ts_all = [t for t in ts_all if 'Switch' not in t.name] ts_all = [t for t in ts_all if 'dropout' not in t.name] # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16 ts_all = [t for t in ts_all if 'Cast' not in t.name] # filter out all tensors that are inputs of the backward graph with util.capture_ops() as bwd_ops: tf_gradients(ys, xs, grad_ys, **kwargs) bwd_inputs = [t for op in bwd_ops for t in op.inputs] # list of tensors in forward graph that is in input to bwd graph ts_filtered = list(set(bwd_inputs).intersection(ts_all)) debug_print("Using tensors {}".format(ts_filtered), 1) # try two slightly different ways of getting bottlenecks tensors # to checkpoint for ts in [ts_filtered, ts_all]: # get all bottlenecks in the graph bottleneck_ts = [] for t in ts: b = set( ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops)) f = set( ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops)) # check that there are not shortcuts b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all) f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all) if not set(b_inp).intersection( f_inp) and len(b_inp) + len(f_inp) >= len(ts_all): bottleneck_ts.append(t) # we have a bottleneck! else: debug_print( "Rejected bottleneck candidate and ops {}".format( [t] + list(set(ts_all) - set(b_inp) - set(f_inp))), 2) # success? or try again without filtering? if len(bottleneck_ts) >= np.sqrt( len(ts_filtered)): # yes, enough bottlenecks found! break if not bottleneck_ts: raise Exception( 'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".' 
) # sort the bottlenecks bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops) sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts] # save an approximately optimal number ~ sqrt(N) N = len(ts_filtered) if len(bottleneck_ts) <= np.ceil(np.sqrt(N)): checkpoints = sorted_bottlenecks else: step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N))) checkpoints = sorted_bottlenecks[step::step] checkpoints = list(set(checkpoints).intersection(ts_all)) # at this point automatic selection happened and checkpoints is list of nodes assert isinstance(checkpoints, list) debug_print("Checkpoint nodes used: {}".format(checkpoints), 1) # better error handling of special cases # xs are already handled as checkpoint nodes, so no need to include them xs_intersect_checkpoints = set(xs).intersection(set(checkpoints)) if xs_intersect_checkpoints: debug_print( "Warning, some input nodes are also checkpoint nodes: {}".format( xs_intersect_checkpoints), 2) ys_intersect_checkpoints = set(ys).intersection(set(checkpoints)) debug_print( "ys: {}, checkpoints: {}, intersect: {}".format( ys, checkpoints, ys_intersect_checkpoints), 1) # saving an output node (ys) gives no benefit in memory while creating # new edge cases, exclude them if ys_intersect_checkpoints: debug_print( "Warning, some output nodes are also checkpoints nodes: {}".format( format_ops(ys_intersect_checkpoints)), 2) # remove initial and terminal nodes from checkpoints list if present checkpoints = list(set(checkpoints) - set(ys) - set(xs)) # check that we have some nodes to checkpoint if not checkpoints: raise Exception('no checkpoints nodes found or given as input! ') debug_print( "Select {} nodes to checkpoint nodes.".format(len(checkpoints)), 0) # disconnect dependencies between checkpointed tensors checkpoints_disconnected = {} for x in checkpoints: frontier_ops = set(graph_util.get_consuming_ops(x.op.outputs)) debug_print("my frontier ops: {}".format(frontier_ops), 1) bw_frontier_ops = frontier_ops & set(bwd_ops) debug_print("my bw frontier ops: {}".format(bw_frontier_ops), 1) if len(bw_frontier_ops) > 1: continue if x.op and x.op.name is not None: grad_node = tf.stop_gradient(x, name=x.op.name + "_sg") else: grad_node = tf.stop_gradient(x) swapout_op = _add_swapout(grad_node.op, grad_node.op.outputs) incpu_count = incpu_count + 1 swapin_op = _add_swapin(swapout_op, bw_frontier_ops, grad_node.op.outputs) checkpoints_disconnected[x] = swapin_op my_add_control_inputs(x, bw_frontier_ops, swapin_op) # control dependency -> swap_in # self._add_control_dependency(src_op, dest_op, swapin_op) # g = tf.get_default_graph() # print(g.get_operations()) # partial derivatives to the checkpointed tensors and xs ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], stop_at_ts=checkpoints, within_ops=fwd_ops) debug_print( "Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format( len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints), 1) debug_print("ops_to_copy = {}".format(ops_to_copy), 1) debug_print("Processing list {}".format(ys), 1) copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) for origin_op, op in info._transformed_ops.items(): op._set_device(origin_op.node_def.device) copied_ops = info._transformed_ops.values() debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1) ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops) debug_print( "Rewired {} in place of {} restricted to {}".format( 
checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops), 1) # get gradients with respect to current boundary + original x's copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys] boundary = list(checkpoints_disconnected.values()) dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys, **kwargs) debug_print("Got gradients {}".format(dv), 1) debug_print("for {}".format(copied_ys), 1) debug_print("with respect to {}".format(boundary + xs), 1) inputs_to_do_before = [y.op for y in ys] if grad_ys is not None: inputs_to_do_before += grad_ys wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) # partial derivatives to the checkpointed nodes # dictionary of "node: backprop" for nodes in the boundary d_checkpoints = { r: dr for r, dr in zip(checkpoints_disconnected.keys(), dv[:len(checkpoints_disconnected)]) } # partial derivatives to xs (usually the params of the neural net) d_xs = dv[len(checkpoints_disconnected):] # incorporate derivatives flowing through the checkpointed nodes checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops) for ts in checkpoints_sorted_lists[::-1]: debug_print("Processing list {}".format(ts), 1) checkpoints_other = [r for r in checkpoints if r not in ts] checkpoints_disconnected_other = [ checkpoints_disconnected[r] for r in checkpoints_other ] # copy part of the graph below current checkpoint node, stopping at # other checkpoints nodes ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other) debug_print( "Found {} ops to copy within {}, seed {}, stop_at {}".format( len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other), 1) debug_print("ops_to_copy = {}".format(ops_to_copy), 1) if not ops_to_copy: # we're done! 
break copied_sgv, info = ge.copy_with_input_replacements( ge.sgv(ops_to_copy), {}) for origin_op, op in info._transformed_ops.items(): op._set_device(origin_op.node_def.device) copied_ops = info._transformed_ops.values() debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1) ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops) debug_print( "Rewired {} in place of {} restricted to {}".format( checkpoints_disconnected_other, checkpoints_other, copied_ops), 1) # gradient flowing through the checkpointed node boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts] substitute_backprops = [d_checkpoints[r] for r in ts] dv = tf_gradients(boundary, checkpoints_disconnected_other + xs, grad_ys=substitute_backprops, **kwargs) debug_print("Got gradients {}".format(dv), 1) debug_print("for {}".format(boundary), 1) debug_print( "with respect to {}".format(checkpoints_disconnected_other + xs), 1) debug_print( "with boundary backprop substitutions {}".format( substitute_backprops), 1) inputs_to_do_before = [d_checkpoints[r].op for r in ts] wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) # partial derivatives to the checkpointed nodes for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]): if dr is not None: if d_checkpoints[r] is None: d_checkpoints[r] = dr else: d_checkpoints[r] += dr def _unsparsify(x): if not isinstance(x, tf.IndexedSlices): return x assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape" indices = x.indices while indices.shape.ndims < x.values.shape.ndims: indices = tf.expand_dims(indices, -1) return tf.scatter_nd(indices, x.values, x.dense_shape) # partial derivatives to xs (usually the params of the neural net) d_xs_new = dv[len(checkpoints_other):] for j in range(len(xs)): if d_xs_new[j] is not None: if d_xs[j] is None: d_xs[j] = _unsparsify(d_xs_new[j]) else: d_xs[j] += _unsparsify(d_xs_new[j]) return d_xs
def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): """Finds fused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise convolution. Args: graph: Graph to walk and modify. is_training: Bool, true if training. freeze_batch_norm_delay: How many steps to wait before freezing moving mean and variance and using them for batch normalization. Raises: ValueError: When batch norm folding fails. """ for match in _FindFusedBatchNorms(graph): scope, sep, _ = match.layer_op.name.rpartition('/') # Make sure new ops are added to `graph` and put on the same device as # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope # named `scope`. Otherwise, TF creates a unique scope whose name starts with # `scope`. with graph.as_default(), graph.name_scope(scope + sep): with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep): # new weights = old weights * gamma / sqrt(variance + epsilon) # new biases = -mean * gamma / sqrt(variance + epsilon) + beta multiplier_tensor = match.gamma_tensor * math_ops.rsqrt( match.variance_tensor + match.bn_op.get_attr('epsilon')) bias_tensor = math_ops.subtract( match.beta_tensor, match.mean_tensor * multiplier_tensor, name='bias') correction_scale, correction_recip, correction_offset = None, None, None if is_training: correction_scale, correction_recip, correction_offset = ( _ComputeBatchNormCorrections( context='', match=match, freeze_batch_norm_delay=freeze_batch_norm_delay, fused_batch_norm=True)) # The shape of depthwise weights is different, so we need to reshape the # multiplier_tensor to ensure that the scaled_weight_tensor has the # expected shape. weights = match.weight_tensor if match.layer_op.type == 'DepthwiseConv2dNative': new_shape = [ match.weight_tensor.get_shape().as_list()[2], match.weight_tensor.get_shape().as_list()[3] ] multiplier_tensor = array_ops.reshape( multiplier_tensor, new_shape, name='scale_reshape') if correction_scale is not None: correction_scale = array_ops.reshape( correction_scale, new_shape, name='correction_reshape') if correction_scale is not None: weights = math_ops.multiply( correction_scale, weights, name='correction_mult') scaled_weight_tensor = math_ops.multiply( weights, multiplier_tensor, name='mul_fold') new_layer_tensor = _CloneWithNewOperands( match.layer_op, match.input_tensor, scaled_weight_tensor) if correction_recip is not None: new_layer_tensor = math_ops.multiply( correction_recip, new_layer_tensor, name='post_conv_mul') new_layer_tensor = math_ops.add(new_layer_tensor, (correction_offset), 'correction_add') bias_add_tensor = math_ops.add( new_layer_tensor, bias_tensor, name='add_fold') nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor, match.output_tensor) if nodes_modified_count == 0: raise ValueError('Folding batch norms failed, %s had no outputs.' % match.output_tensor.name)
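The algebra this fold relies on, checked numerically (NumPy, illustrative): applying batch norm to a conv/FC output y is the same affine map as scaling by gamma/sqrt(var + eps) and adding beta - mean*gamma/sqrt(var + eps), which is why the scale can be folded into the weights and the offset into the bias-add that reroute_ts swaps in for the original batch norm output.

import numpy as np

rng = np.random.RandomState(0)
y = rng.randn(32, 16)                      # stand-in for the layer output pre-BN
gamma, beta = rng.randn(16), rng.randn(16)
mean, var, eps = y.mean(0), y.var(0), 1e-3

bn_out = gamma * (y - mean) / np.sqrt(var + eps) + beta

multiplier = gamma / np.sqrt(var + eps)    # folded into the weights (mul_fold)
bias = beta - mean * multiplier            # becomes the add_fold bias
folded = y * multiplier + bias

assert np.allclose(bn_out, folded)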
def AddIntegratedGradientsOps(graph, attribution_tensors, output_tensor, num_evals, attribution_dims_map, zero_baseline_tensors=None, new_output_scope='attribution', baseline_scope='baseline', tensors_to_keep=None): """Modify graph to create ops for computing integrated gradients. Function to modify a tensorflow graph by adding ops for attributing the change in value of a given output tensor, to different input 'attribution_tensors' (see arxiv.org/abs/1703.01365). The first dimension of each attribution_tensor and output_tensor is assumed to be the batch dimension. That is, if we create multiple input values for the attribution tensors, we should be able to concatenate them along the first dimension, and the resulting output tensor should have corresponding values for different values of its first dimension. The attribution works by interpolating between a given input, and a given baseline, to create multiple (num_evals) interpolated inputs. At each interpolated input, we compute the gradient of the output tensor with respect to each attribution tensor. The gradients for each attribution tensor are averaged over all interpolated inputs, to get an attribution score for it. Example Usage: attribution_feed_dict = AddIntegratedGradientsOps(...) Then to get attribution for a given input (specified by input_feed_dict, relative to a baseline given by baseline_feed_dict): combined_feed_dict = attribution_feed_dict['create_combined_feed_dict']( input_feed_dict, baseline_feed_dict) with graph.as_default(), sess.as_default(): attributions = sess.run( attribution_feed_dict['mean_grads'], combined_feed_dict) for tensor, attribution in zip(attribution_tensors, attributions): print('Attribution for %s: %s' % (tensor.op.name, attribution)) Warning: This function is not compatible with tf.cond. If there is a tf.cond in the graph path between the attribution tensors and the output tensor, the attribution ops may not work. # TODO(manasrj): Make attribution ops compatible with tf.cond. Args: graph: The tf.Graph to add attribution ops to. attribution_tensors: Tensors for which to compute attribution scores. The tensors must satisfy two properties: (1) The output tensor must be computable given values for attribution tensors. (2) Each attribution tensor must be computationally independent of the others, i.e., it should not be the case that one of the attribution tensor's value is completely determined by the values of the other attribution tensors. Properties (1) and (2) ensure the attribution tensors form an input-output cut in the computation graph. output_tensor: Tensor for whose value we are performing attribution. num_evals: Integer scalar. Number of interpolated points at which to evaluate gradients. Higher values of this parameter increase computation time, but also increase accuracy of attributions. attribution_dims_map: Dict mapping attribution tensors to lists of integers. For each attribution_tensor, we compute a separate gradient value for each slice along the dims in the list. For example, if we have a rank 3 attribution tensor T that consists of embeddings lookups, with the first dimension being the batch dimension, and the second dimension being the sparse ids, then setting attribution_dims_map[T] = [1] will give us a separate gradient for each sparse id. If an attribution_tensor has no entry in attribution_dims_map, then the list defaults to []. zero_baseline_tensors: Set of attribution tensors. 
For each tensor T in this set, we compute gradients with respect to T for all interpolated values of T between the value computed from the input feed, and zero. For each tensor U not in zero_baseline_tensors, we compute gradients for interpolated values between the one derived from the input feed, and the one derived from the baseline feed. new_output_scope: String. New ops needed for computing the output tensor at different interpolated values are created under this scope name. baseline_scope: String. New ops needed for computing attribution tensor interpolated values are created under this scope name. tensors_to_keep: Set of tensors. By default, tensors in the graph between the output_tensor and attribution tensors are copied to a different part of the graph, and evaluated separately for each interpolation. If we want a value to be fixed (only computed for the main input instead of each interpolation), it should be put in tensors_to_keep. Returns: attribution_hooks: Dict with the following keys (among others): mean_grads: List of attribution scores (aligned with attribution_tensors). create_combined_feed_dict: A Function that takes an input feed dict, and optionally, a baseline feed dict, and creates a combined feed dict to pass to sess.run to get attributions. """ ops_to_tensors = lambda ops: [op.outputs[0] for op in ops] attribution_hooks = {} if tensors_to_keep is None: tensors_to_keep = [] else: tensors_to_keep = list(tensors_to_keep) if zero_baseline_tensors is None: zero_baseline_tensors = [] with graph.as_default(): # Compute parts of graph and check correctness. all_ops = graph.get_operations() constant_ops = contrib_graph_editor.select.select_ops( all_ops, positive_filter=lambda x: x.type == 'Const') placeholder_ops = contrib_graph_editor.select.select_ops( all_ops, positive_filter=lambda x: x.type == 'Placeholder') var_read_ops = contrib_graph_editor.select.select_ops('/read$', graph=graph) attr_ops = [t.op for t in attribution_tensors] required_ops = set( contrib_graph_editor.select.get_backward_walk_ops( output_tensor.op, stop_at_ts=(tensors_to_keep + list(attribution_tensors) + ops_to_tensors(var_read_ops) + ops_to_tensors(placeholder_ops)))) # Check that attribution tensors are sufficient to compute output_tensor. forward_ops = set( contrib_graph_editor.select.get_forward_walk_ops(attr_ops + var_read_ops + constant_ops)) assert required_ops.issubset(forward_ops) required_sgv = contrib_graph_editor.subgraph.make_view(required_ops) attribution_subgraph, attribution_transform_info = ( contrib_graph_editor.transform.copy_with_input_replacements( required_sgv, {}, graph, new_output_scope)) attribution_hooks['attribution_subgraph'] = attribution_subgraph attribution_hooks[ 'attribution_transform_info'] = attribution_transform_info # Copy feed to attribution part of graph so we can have one part for # baseline and one for input. backward_ops = contrib_graph_editor.select.get_backward_walk_ops( attr_ops, stop_at_ts=ops_to_tensors(var_read_ops)) backward_sgv = contrib_graph_editor.subgraph.make_view(backward_ops) _, baseline_transform_info = ( contrib_graph_editor.transform.copy_with_input_replacements( backward_sgv, {}, graph, baseline_scope)) attribution_hooks['baseline_transform_info'] = baseline_transform_info # Function to compute combined feed dict. The default setting of # baseline_transform_info is to get around python's late binding. 
def CreateCombinedFeedDict( input_feed_dict, baseline_feed_dict=None, baseline_transform_info=baseline_transform_info): """Combine baseline and input feed dicts into a common feed dict.""" combined_feed_dict = input_feed_dict.copy() if baseline_feed_dict is None: baseline_feed_dict = input_feed_dict for key, feed_value in baseline_feed_dict.items(): if isinstance(key, tf.Tensor): combined_feed_dict[baseline_transform_info.transformed( key)] = (feed_value) elif isinstance(key, six.text_type): if six.PY2: tensor = graph.get_tensor_by_name(key.decode()) else: tensor = graph.get_tensor_by_name(key) combined_feed_dict[baseline_transform_info.transformed( tensor)] = (feed_value) elif isinstance(key, tf.SparseTensor): sparse_transformed_tensor = tf.SparseTensor( baseline_transform_info.transformed(key.indices), baseline_transform_info.transformed(key.values), baseline_transform_info.transformed(key.dense_shape)) combined_feed_dict[sparse_transformed_tensor] = feed_value else: raise ValueError('Invalid key type %s in Feed Dict.' % type(key)) return combined_feed_dict attribution_hooks['create_combined_feed_dict'] = CreateCombinedFeedDict # Create new tensors with the multipliers to insert after previous ones. attribution_hooks['multipliers'] = [] attribution_hooks['weighted_attribution_tensors'] = [] for attribution_tensor in attribution_tensors: with tf.control_dependencies( [tf.assert_equal(tf.shape(attribution_tensor)[0], 1)]): attribution_dims = (attribution_dims_map[attribution_tensor] if attribution_tensor in attribution_dims_map else []) vocab_size = len(attribution_tensor.get_shape()) attribution_dim_cond = tf.sparse_to_indicator( tf.SparseTensor( tf.expand_dims( tf.range(len(attribution_dims), dtype=tf.int64), 1), attribution_dims, [vocab_size]), vocab_size) base_multiplier_shape = tf.concat([ tf.expand_dims(num_evals, 0), tf.ones_like(tf.shape(attribution_tensor))[1:] ], 0) tile_dims = tf.where( attribution_dim_cond, tf.shape(attribution_tensor), tf.ones_like(tf.shape(attribution_tensor))) pre_tile_multiplier = tf.reshape( tf.range(tf.to_float(num_evals)) / tf.to_float(num_evals - 1), base_multiplier_shape) multiplier = tf.tile(pre_tile_multiplier, tile_dims) if attribution_tensor in zero_baseline_tensors: weighted_attribution_tensor = multiplier * attribution_tensor else: base_attribution_tensor = baseline_transform_info.transformed( attribution_tensor) weighted_attribution_tensor = ( multiplier * attribution_tensor + (1 - multiplier) * base_attribution_tensor) attribution_hooks['weighted_attribution_tensors'].append( weighted_attribution_tensor) attribution_hooks['multipliers'].append(multiplier) contrib_graph_editor.reroute_ts( attribution_hooks['weighted_attribution_tensors'], attribution_tensors, can_modify=attribution_subgraph.ops) g = tf.gradients(attribution_transform_info.transformed(output_tensor), attribution_hooks['multipliers']) attribution_hooks['mean_grads'] = [ tf.reduce_mean(grad, 0) for grad in g ] return attribution_hooks
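For reference, the quantity all of the graph rewriting above estimates, as a plain NumPy sketch (Riemann-sum form of integrated gradients from arxiv.org/abs/1703.01365; `grad_fn` is a hypothetical callable returning dF/dz at z):

import numpy as np

def integrated_gradients(grad_fn, x, baseline, num_evals=50):
    """Average gradients along the straight line from `baseline` to `x`."""
    alphas = np.linspace(0.0, 1.0, num_evals)
    grads = np.stack([grad_fn(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)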
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): ''' Authors: Tim Salimans & Yaroslav Bulatov memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174) ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients) 'checkpoints' can either be - a list consisting of tensors from the forward pass of the neural net that we should re-use when calculating the gradients in the backward pass all other tensors that do not appear in this list will be re-computed - a string specifying how this list should be determined. currently we support - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive, so checkpointing them maximizes the running speed (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory) - 'memory': try to minimize the memory usage (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint) - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint ''' # print("Calling memsaving gradients with", checkpoints) if not isinstance(ys, list): ys = [ys] if not isinstance(xs, list): xs = [xs] bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True) debug_print("bwd_ops: %s", bwd_ops) # forward ops are all ops that are candidates for recomputation fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], inclusive=True, within_ops=bwd_ops) debug_print("fwd_ops: %s", fwd_ops) # exclude ops with no inputs fwd_ops = [op for op in fwd_ops if op.inputs] # don't recompute xs, remove variables xs_ops = _to_ops(xs) fwd_ops = [op for op in fwd_ops if not op in xs_ops] fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] fwd_ops = [op for op in fwd_ops if not '/read' in op.name] ts_all = ge.filter_ts(fwd_ops, True) # get the tensors ts_all = [t for t in ts_all if '/read' not in t.name] ts_all = set(ts_all) - set(xs) - set(ys) # construct list of tensors to checkpoint during forward pass, if not # given as input if type(checkpoints) is not list: if checkpoints == 'collection': checkpoints = tf.get_collection('checkpoints') elif checkpoints == 'speed': # checkpoint all expensive ops to maximize running speed checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul') elif checkpoints == 'memory': # remove very small tensors and some weird ops def fixdims( t ): # tf.Dimension values are not compatible with int, convert manually try: return [int(e if e.value is not None else 64) for e in t] except: return [0] # unknown shape ts_all = [ t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE ] ts_all = [t for t in ts_all if 'L2Loss' not in t.name] ts_all = [t for t in ts_all if 'entropy' not in t.name] ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name] ts_all = [t for t in ts_all if 'Switch' not in t.name] ts_all = [t for t in ts_all if 'dropout' not in t.name] # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16 ts_all = [t for t in ts_all if 'Cast' not in t.name] # filter out all tensors that are inputs of the backward graph with util.capture_ops() as bwd_ops: tf_gradients(ys, xs, grad_ys, **kwargs) bwd_inputs = [t for op in bwd_ops for t in op.inputs] # list of tensors in 
forward graph that is in input to bwd graph ts_filtered = list(set(bwd_inputs).intersection(ts_all)) debug_print("Using tensors %s", ts_filtered) # try two slightly different ways of getting bottlenecks tensors # to checkpoint for ts in [ts_filtered, ts_all]: # get all bottlenecks in the graph bottleneck_ts = [] for t in ts: b = set( ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops)) f = set( ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops)) # check that there are not shortcuts b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all) f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all) if not set(b_inp).intersection( f_inp) and len(b_inp) + len(f_inp) >= len(ts_all): bottleneck_ts.append(t) # we have a bottleneck! else: debug_print( "Rejected bottleneck candidate and ops %s", [t] + list(set(ts_all) - set(b_inp) - set(f_inp))) # success? or try again without filtering? if len(bottleneck_ts) >= np.sqrt( len(ts_filtered)): # yes, enough bottlenecks found! break if not bottleneck_ts: raise Exception( 'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".' ) # sort the bottlenecks bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops) sorted_bottlenecks = [ t for ts in bottlenecks_sorted_lists for t in ts ] # save an approximately optimal number ~ sqrt(N) N = len(ts_filtered) if len(bottleneck_ts) <= np.ceil(np.sqrt(N)): checkpoints = sorted_bottlenecks else: step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N))) checkpoints = sorted_bottlenecks[step::step] else: raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints, )) checkpoints = list(set(checkpoints).intersection(ts_all)) # at this point automatic selection happened and checkpoints is list of nodes assert isinstance(checkpoints, list) debug_print("Checkpoint nodes used: %s", checkpoints) # better error handling of special cases # xs are already handled as checkpoint nodes, so no need to include them xs_intersect_checkpoints = set(xs).intersection(set(checkpoints)) if xs_intersect_checkpoints: debug_print("Warning, some input nodes are also checkpoint nodes: %s", xs_intersect_checkpoints) ys_intersect_checkpoints = set(ys).intersection(set(checkpoints)) debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints, ys_intersect_checkpoints) # saving an output node (ys) gives no benefit in memory while creating # new edge cases, exclude them if ys_intersect_checkpoints: debug_print( "Warning, some output nodes are also checkpoints nodes: %s", format_ops(ys_intersect_checkpoints)) # remove initial and terminal nodes from checkpoints list if present checkpoints = list(set(checkpoints) - set(ys) - set(xs)) # check that we have some nodes to checkpoint # if not checkpoints: # raise Exception('no checkpoints nodes found or given as input! 
') # disconnect dependencies between checkpointed tensors checkpoints_disconnected = {} for x in checkpoints: if x.op and x.op.name is not None: grad_node = tf.stop_gradient(x, name=x.op.name + "_sg") else: grad_node = tf.stop_gradient(x) checkpoints_disconnected[x] = grad_node # partial derivatives to the checkpointed tensors and xs ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], stop_at_ts=checkpoints, within_ops=fwd_ops) debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s", len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints) debug_print("ops_to_copy = %s", ops_to_copy) debug_print("Processing list %s", ys) copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) for origin_op, op in info._transformed_ops.items(): op._set_device(origin_op.node_def.device) copied_ops = info._transformed_ops.values() debug_print("Copied %s to %s", ops_to_copy, copied_ops) ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops) debug_print("Rewired %s in place of %s restricted to %s", checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops) # get gradients with respect to current boundary + original x's copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys] boundary = list(checkpoints_disconnected.values()) dv = tf_gradients(ys=copied_ys, xs=boundary + xs, grad_ys=grad_ys, **kwargs) debug_print("Got gradients %s", dv) debug_print("for %s", copied_ys) debug_print("with respect to %s", boundary + xs) inputs_to_do_before = [y.op for y in ys] if grad_ys is not None: inputs_to_do_before += grad_ys wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) # partial derivatives to the checkpointed nodes # dictionary of "node: backprop" for nodes in the boundary d_checkpoints = { r: dr for r, dr in zip(checkpoints_disconnected.keys(), dv[:len(checkpoints_disconnected)]) } # partial derivatives to xs (usually the params of the neural net) d_xs = dv[len(checkpoints_disconnected):] # incorporate derivatives flowing through the checkpointed nodes checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops) for ts in checkpoints_sorted_lists[::-1]: debug_print("Processing list %s", ts) checkpoints_other = [r for r in checkpoints if r not in ts] checkpoints_disconnected_other = [ checkpoints_disconnected[r] for r in checkpoints_other ] # copy part of the graph below current checkpoint node, stopping at # other checkpoints nodes ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other) debug_print("Found %s ops to copy within %s, seed %s, stop_at %s", len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other) debug_print("ops_to_copy = %s", ops_to_copy) if not ops_to_copy: # we're done! 
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops, **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, \
                "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
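
# A hedged usage sketch of the memory-saving gradients routine above, assuming
# TF 1.x graph mode and that the gradients(ys, xs, ..., checkpoints=...)
# function is exposed from a module named memory_saving_gradients; the module
# name, the exact signature, and the tiny model below are illustrative
# assumptions, not part of the source.
import tensorflow as tf

import memory_saving_gradients  # assumed home of the gradients() shown above

x = tf.placeholder(tf.float32, [None, 32], name='x')
h = tf.layers.dense(x, 64, activation=tf.nn.relu, name='hidden')
out = tf.layers.dense(h, 1, name='out')
loss = tf.reduce_mean(tf.square(out))

params = tf.trainable_variables()
# Passing an explicit list of tensors checkpoints them manually; passing
# checkpoints='memory' would instead run the automatic sqrt(N) bottleneck
# selection shown above.
grads = memory_saving_gradients.gradients([loss], params, checkpoints=[h])
train_op = tf.train.AdamOptimizer(1e-3).apply_gradients(
    list(zip(grads, params)))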
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
                                 fused_batch_norm):
  """Computes batch norm correction params.

  Before batch normalization is frozen, batch statistics are used for batch
  norm:
    correction_scale = sigma_b / sigma_mv
    correction_recip = 1 / correction_scale
    correction_offset = 0

  After batch normalization is frozen:
    correction_scale = sigma_b / sigma_mv
    correction_recip = 1
    correction_offset = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)

  Batch norm is frozen if global_step > bn_freeze_delay. The corrections
  ensure that:
  a) The weights are quantized after scaling by gamma / sigma_mv. This enables
     smoother training, as the scaling on the weights changes slowly rather
     than jumping across mini-batches.
  b) Changing the values of the corrections allows one to switch between using
     batch statistics and using the moving averages, without requiring changes
     to batch_norm.

  Args:
    context: The scope under which we look for batch norm params.
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.
    fused_batch_norm: Bool, true if fused batch norm is used.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset.
  """
  g = ops.get_default_graph()
  prefix = '' if not context else context + '/'
  with g.name_scope(prefix + 'batch_norm_correction'):
    recip_sigma_mv = math_ops.rsqrt(match.moving_variance_tensor +
                                    match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    correction_scale = math_ops.divide(recip_sigma_mv, recip_sigma,
                                       name='scale_compute')
    correction_scale = array_ops.identity(correction_scale,
                                          name='correction_scale')
    correction_recip = math_ops.reciprocal(correction_scale,
                                           name='reciprocal_compute')
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma -
        match.moving_mean_tensor * recip_sigma_mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())

    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')
    graph_editor.reroute_ts([bn_decay_mean_out],
                            [match.bn_decay_mean_tensor],
                            can_modify=bn_decay_mean_consumers)

    if fused_batch_norm is False:
      bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
      bn_decay_var_out = utils.smart_cond(
          use_mv_avg,
          lambda: bn_decay_zero,
          lambda: match.bn_decay_var_tensor,
          name='freeze_moving_var')
      graph_editor.reroute_ts([bn_decay_var_out],
                              [match.bn_decay_var_tensor],
                              can_modify=bn_decay_var_consumers)

    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')
    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset
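
# A small NumPy check (not from the source) of the correction algebra in the
# docstring above. It assumes the folded layer consumes the corrections as
#   y = conv_out * (gamma / sigma_mv) * correction_recip
#       + correction_offset + (beta - gamma * mu_b / sigma_b)
# which mirrors how the surrounding folding logic applies them; that
# consumption pattern is an assumption here, not defined by this function.
import numpy as np

gamma, beta = 1.7, 0.3
mu_b, var_b = 0.5, 4.0        # batch statistics
mu_mv, var_mv = 0.4, 9.0      # moving-average statistics
sigma_b, sigma_mv = np.sqrt(var_b), np.sqrt(var_mv)
conv_out = 2.0                # pre-batch-norm activation

correction_scale = sigma_b / sigma_mv  # same before and after freezing

# Before freezing: correction_recip cancels the moving-variance scaling, so
# the folded layer matches batch norm computed with batch statistics.
correction_recip, correction_offset = 1.0 / correction_scale, 0.0
y = (conv_out * (gamma / sigma_mv) * correction_recip + correction_offset
     + (beta - gamma * mu_b / sigma_b))
assert np.isclose(y, gamma * (conv_out - mu_b) / sigma_b + beta)

# After freezing: the offset switches the same folded graph over to the
# moving mean and variance without changing the weight scaling.
correction_recip = 1.0
correction_offset = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)
y = (conv_out * (gamma / sigma_mv) * correction_recip + correction_offset
     + (beta - gamma * mu_b / sigma_b))
assert np.isclose(y, gamma * (conv_out - mu_mv) / sigma_mv + beta)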