def expand_softmax(tf_graph, tf_op):
    # type: (TFGraph, TFOperation)->None
    assert tf_op.input.rank != 0

    # Normalize the axis attribute to a non-negative index
    axis = tf_op.attribs.get('axis')
    if axis is None:
        axis = -1
    if axis < 0:
        axis += tf_op.input.rank

    # After expansion the softmax always runs on the last axis
    tf_op.attribs['axis'] = -1

    # Already a 2-D softmax over the last axis: nothing to expand
    if tf_op.input.rank == 2 and axis == 1:
        return

    if axis != tf_op.input.rank - 1:
        # Move the softmax axis to the last position, and transpose back afterwards
        perm = utils.without(range(tf_op.input.rank), axis) + [axis]
        perm_inv = utils.inverse_permutation(perm)
        transpose = TFOperation(graph=tf_graph,
                                name="tf.transpose",
                                inputs=tf_op.input,
                                attribs=dict(perm=perm),
                                outputs=TFTensor(graph=tf_graph,
                                                 name=None,
                                                 shape=infer.transpose(input=tf_op.input.shape, axes=perm),
                                                 dtype=tf_op.input.dtype))
        tf_op.inputs = transpose.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.transpose",
                    inputs=tf_op.output,
                    attribs=dict(perm=perm_inv),
                    outputs=old_output)

    if tf_op.input.rank != 2:
        # Flatten all leading dimensions so the softmax input is 2-D, then restore the shape
        shape = [-1, tf_op.input.shape[-1]]
        reshape = TFOperation(graph=tf_graph,
                              name="tf.reshape",
                              inputs=tf_op.input,
                              attribs=dict(shape=shape),
                              outputs=TFTensor(graph=tf_graph,
                                               name=None,
                                               shape=infer.reshape(input=tf_op.input.shape, shape=shape),
                                               dtype=tf_op.input.dtype))
        tf_op.inputs = reshape.output
        old_output = tf_op.output
        tf_op.outputs = TFTensor(graph=tf_graph,
                                 name=None,
                                 shape=list(tf_op.input.shape),
                                 dtype=tf_op.input.dtype)
        TFOperation(graph=tf_graph,
                    name="tf.reshape",
                    inputs=tf_op.output,
                    attribs=dict(shape=old_output.shape),
                    outputs=old_output)
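# Illustrative sketch (not part of the converter): a minimal NumPy check of
# why the expansion above is sound. Softmax over an arbitrary axis equals
# transpose-to-last + reshape-to-2D + row-wise softmax + the inverse reshape
# and transpose. The helper below is hypothetical and exists only for this
# demonstration.
def _demo_expand_softmax_equivalence():
    import numpy as np

    def softmax(x, axis):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    x = np.random.rand(2, 3, 4, 5)
    axis = 1

    # Mirror expand_softmax: move `axis` last, flatten to 2-D...
    perm = [d for d in range(x.ndim) if d != axis] + [axis]
    y = x.transpose(perm).reshape(-1, x.shape[axis])
    # ...apply the plain 2-D softmax the converter targets...
    y = softmax(y, axis=-1)
    # ...then undo the reshape and the transpose.
    y = y.reshape([x.shape[d] for d in perm]).transpose(np.argsort(perm))

    assert np.allclose(y, softmax(x, axis=axis))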
def transform_fused_batch_norm(g, op):
    # type: (TFGraph, TFOperation)->None
    VARIANCE_CORRECTION_ENABLED = True

    in_input = op.inputs[0]
    in_scale = op.inputs[1]
    in_offset = op.inputs[2]
    epsilon = op.attribs["epsilon"]
    out_y = op.outputs[0]
    out_batch_mean = op.outputs[1]
    out_batch_var = op.outputs[2]
    data_format = op.attribs["data_format"].upper() if op.attribs["data_format"] else "NHWC"
    channel_dim = 1 if data_format == "NCHW" else in_input.rank - 1
    # Number of elements averaged per channel (fixed: divide by the channel
    # size, not by the channel axis index)
    rest_count = int(in_input.count / in_input.shape[channel_dim])

    tensors_to_remove = []

    if op.attribs["is_training"]:
        if VARIANCE_CORRECTION_ENABLED:
            biased_batch_var = TFTensor(graph=g, shape=list(out_batch_var.shape), dtype=out_batch_var.dtype)
            # Bessel correction factor n / (n - 1): tf.nn.moments computes the
            # biased variance, while the batch variance output is unbiased
            const = TFTensor(graph=g, shape=[], dtype=in_input.dtype,
                             data=float(rest_count) / max(rest_count - 1, 1))
            TFOperation(graph=g,
                        name="tf.nn.moments",
                        inputs=in_input,
                        attribs=dict(axes=utils.without(range(in_input.rank), channel_dim), keep_dims=False),
                        outputs=(out_batch_mean, biased_batch_var))
            TFOperation(graph=g,
                        name="tf.multiply",
                        inputs=(biased_batch_var, const),
                        outputs=out_batch_var)
            TFOperation(graph=g,
                        name="tf.nn.batch_normalization",
                        inputs=(in_input, out_batch_mean, out_batch_var, in_offset, in_scale),
                        attribs=dict(variance_epsilon=epsilon, _data_format=data_format),
                        outputs=out_y)
            if len(op.outputs) > 3:  # This can happen in gradients
                out_saved_mean = op.outputs[3]
                out_saved_var = op.outputs[4]
                graph_utils.replace_tensor_in_consumers(g, out_saved_mean, out_batch_mean)
                graph_utils.replace_tensor_in_consumers(g, out_saved_var, out_batch_var)
                tensors_to_remove += [out_saved_mean, out_saved_var]
        else:  # not VARIANCE_CORRECTION_ENABLED
            TFOperation(graph=g,
                        name="tf.nn.moments",
                        inputs=in_input,
                        attribs=dict(axes=utils.without(range(in_input.rank), channel_dim), keep_dims=False),
                        outputs=(out_batch_mean, out_batch_var))
            TFOperation(graph=g,
                        name="tf.nn.batch_normalization",
                        inputs=(in_input, out_batch_mean, out_batch_var, in_offset, in_scale),
                        attribs=dict(variance_epsilon=epsilon, _data_format=data_format),
                        outputs=out_y)
            if len(op.outputs) > 3:  # This can happen in gradients
                out_saved_mean = op.outputs[3]
                out_saved_var = op.outputs[4]
                graph_utils.replace_tensor_in_consumers(g, out_saved_mean, out_batch_mean)
                graph_utils.replace_tensor_in_consumers(g, out_saved_var, out_batch_var)
                tensors_to_remove += [out_saved_mean, out_saved_var]
    else:  # not training: use the supplied moving mean and variance
        in_mean = op.inputs[3]
        in_variance = op.inputs[4]
        graph_utils.replace_tensor_in_consumers(g, out_batch_mean, in_mean)
        graph_utils.replace_tensor_in_consumers(g, out_batch_var, in_variance)
        tensors_to_remove += [out_batch_mean, out_batch_var]
        if len(op.outputs) > 3:  # This can happen in gradients
            out_saved_mean = op.outputs[3]
            out_saved_var = op.outputs[4]
            graph_utils.replace_tensor_in_consumers(g, out_saved_mean, in_mean)
            graph_utils.replace_tensor_in_consumers(g, out_saved_var, in_variance)
            tensors_to_remove += [out_saved_mean, out_saved_var]
        TFOperation(graph=g,
                    name="tf.nn.batch_normalization",
                    inputs=(in_input, in_mean, in_variance, in_offset, in_scale),
                    attribs=dict(variance_epsilon=epsilon, _data_format=data_format),
                    outputs=out_y)

    g.remove_operation(op, unlink=True)
    for t in tensors_to_remove:
        g.remove_tensor(t)
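# Illustrative sketch (not part of the converter): a NumPy check of the
# Bessel correction applied above. tf.nn.moments yields the biased variance,
# while the transform assumes the fused op's batch-variance output is the
# unbiased one, hence the n / (n - 1) factor with n = rest_count.
def _demo_variance_correction():
    import numpy as np

    x = np.random.rand(8, 16, 16, 4)              # NHWC
    channel_dim = x.ndim - 1
    axes = tuple(d for d in range(x.ndim) if d != channel_dim)
    rest_count = x.size // x.shape[channel_dim]   # elements per channel

    biased = x.var(axis=axes)                     # what tf.nn.moments computes
    unbiased = x.var(axis=axes, ddof=1)           # the corrected batch variance
    factor = float(rest_count) / max(rest_count - 1, 1)
    assert np.allclose(biased * factor, unbiased)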