Example #1
    def insert_bias_add_op(sess: tf.compat.v1.Session,
                           conv_op_out_tensor: tf.Tensor,
                           new_bias_tensor: tf.Variable,
                           bias_name="bias_value") -> None:
        """
        Insert bias-add op to given conv op.
        :param sess: model as tf.compat.v1.Session
        :param conv_op_out_tensor: output of conv op that should feed into the new bias op as tf.Tensor
        :param new_bias_tensor:  bias tensor to be added as tf.Variable
        :param bias_name: name string for the bias op
        :return: None
        Note: higher-level APIs need to perform a save and load to get an updated session after using this API
        """

        assert conv_op_out_tensor is not None, 'Error, insert_bias_add_op() : conv op output tensor must be provided'
        with sess.graph.as_default():
            if conv_op_out_tensor.consumers():

                consumer_list = []
                for consumer in conv_op_out_tensor.consumers():
                    consumer_list.append(consumer)

                # create new Bias add op
                bias_add_op = tf.nn.bias_add(value=conv_op_out_tensor,
                                             bias=new_bias_tensor,
                                             name=bias_name)

                # use reroute to insert bias-add and swap current outputs of conv with bias-add op
                ge.reroute_ts(bias_add_op,
                              conv_op_out_tensor,
                              can_modify=consumer_list)

                # initialize tensor once it's added
                sess.run(tf.compat.v1.variables_initializer([new_bias_tensor]))
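A minimal usage sketch for the helper above (op and variable names are hypothetical; sess is assumed to be a live tf.compat.v1.Session with the function above in scope):

    import numpy as np
    import tensorflow as tf

    conv_op = sess.graph.get_operation_by_name('conv1/Conv2D')   # hypothetical conv op name
    num_out_channels = conv_op.outputs[0].get_shape().as_list()[-1]

    with sess.graph.as_default():
        # zero-initialized bias, one element per output channel
        new_bias = tf.Variable(np.zeros(num_out_channels, dtype=np.float32), name='conv1_bias')

    insert_bias_add_op(sess, conv_op.outputs[0], new_bias)
    # per the docstring, save and re-import the graph afterwards to get an updated session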
def convert_ckpt_to_pb(path, checkpoint, pb_path):
    """Convert ckpt files to a *.pb model.
    
    Args:
    * path: the folder of ckpt files.
    * checkpoint: the *.ckpt file.
    * pb_path: file path to the *.pb model
    """
    graph = tf.Graph()
    with graph.as_default():
        meta_path = os.path.join(path, checkpoint) + '.meta'
        saver = tf.train.import_meta_graph(meta_path)
        inputs = tf.get_collection('images')[0]
        image = tf.placeholder(dtype=tf.float32, shape=[None, HEIGHT, WIDTH, 3], name='images')
        graph_editor.reroute_ts(image, inputs)
        
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(graph=graph, config=config) as sess:
        checkpoint_path = os.path.join(path, checkpoint)
        saver.restore(sess, checkpoint_path)
        input_graph_def = graph.as_graph_def()
        output_graph_def = graph_util.convert_variables_to_constants(
                sess,
                input_graph_def,
                ['output']
                )
        with tf.gfile.GFile(pb_path, 'wb') as f:
            f.write(output_graph_def.SerializeToString())
    print('successfully converted to pb')
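Hypothetical invocation of the converter above (paths and the checkpoint name are placeholders; HEIGHT, WIDTH and the 'output' node name must match the model being converted):

    convert_ckpt_to_pb(path='./ckpt_dir', checkpoint='model.ckpt-10000', pb_path='./frozen_model.pb')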
    def skip_bn_op(sess: tf.compat.v1.Session, bn_op: tf.Operation,
                   in_tensor: tf.Tensor, out_tensor: tf.Tensor):
        """
        Skip the given bn op (fused batch norm op).
        Note: supports only Fused bn op types.

        :param sess: Tensorflow session
        :param bn_op: Batchnorm op to be skipped
        :param in_tensor: Input tensor to the batchnorm op
        :param out_tensor: Output tensor of the batchnorm op
        """

        if in_tensor is None or out_tensor is None:
            logger.error(
                "Error, input and output tensors must be provided for skipping the op"
            )
            assert False
        else:
            with sess.graph.as_default():
                if bn_op.type in ['FusedBatchNormV3', 'FusedBatchNorm']:
                    ge.detach_outputs(in_tensor.op)
                    ge.reroute_ts(in_tensor, out_tensor)
                    BNUtils.remove_bn_op_from_update_ops(sess, bn_op)
                else:
                    logger.error("Error, Unknown BN op")
                    assert False
Example #4
  def test_reroute(self):
    ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
    self.assertTrue(match.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op))
    self.assertTrue(match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))

    ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
    self.assertTrue(match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
    self.assertTrue(match.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op))
Example #5
    def test_reroute(self):
        ge.reroute_ts([self.a0, self.b0], [self.a1, self.b1])
        self.assertTrue(
            match.OpMatcher("c0").input_ops("a0", "b0")(self.c0.op))
        self.assertTrue(
            match.OpMatcher("c1").input_ops("a0", "b0")(self.c1.op))

        ge.reroute_ts([self.a1, self.b1], [self.a0, self.b0])
        self.assertTrue(
            match.OpMatcher("c0").input_ops("a1", "b1")(self.c0.op))
        self.assertTrue(
            match.OpMatcher("c1").input_ops("a1", "b1")(self.c1.op))
Example #6
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        if not _IsValidUnfusedBatchNorm(graph, bn):
            continue

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(
            graph,
            bn,
            has_scaling=has_scaling,
            freeze_batch_norm_delay=freeze_batch_norm_delay,
            is_training=is_training)

        activation = common.GetEndpointActivationOp(graph, bn)
        if activation:
            nodes_modified_count = graph_editor.reroute_ts(
                [folded_op.outputs[0]], [original_op.outputs[0]],
                can_modify=[activation])
            if nodes_modified_count != 1:
                raise ValueError('Unexpected inputs to op: %s' %
                                 activation.name)
            continue

        # Treat consumer ops in bypass modules differently since they have Add
        # operations instead of Relu* above.
        add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
        add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
        nodes_modified_count = graph_editor.reroute_ts(
            [folded_op.outputs[0]], [original_op.outputs[0]],
            can_modify=[add_bypass])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
  """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, True if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
  input_to_ops_map = input_to_ops.InputToOps(graph)

  for bn in common.BatchNormGroups(graph):
    has_scaling = _HasScaling(graph, input_to_ops_map, bn)

    if not _IsValidUnfusedBatchNorm(graph, bn):
      continue

    # The mangling code intimately depends on BatchNorm node's internals.
    original_op, folded_op = _CreateFoldedOp(
        graph,
        bn,
        has_scaling=has_scaling,
        freeze_batch_norm_delay=freeze_batch_norm_delay,
        is_training=is_training)

    activation = common.GetEndpointActivationOp(graph, bn)
    if activation:
      nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
                                                     [original_op.outputs[0]],
                                                     can_modify=[activation])
      if nodes_modified_count != 1:
        raise ValueError('Unexpected inputs to op: %s' % activation.name)
      continue

    # Treat consumer ops in bypass modules differently since they have Add
    # operations instead of Relu* above.
    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
    nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
                                                   [original_op.outputs[0]],
                                                   can_modify=[add_bypass])
    if nodes_modified_count != 1:
      raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
def _create_trainable_graph(
    adjustable_network: AdjustableThresholdsModel,
    reference_network: RegularModel,
    learning_strategy: LearningStrategy,
) -> Tuple[tf.Graph, tf.Tensor, TrainOperations]:

    adj_graph, [adj_input], [adj_output] = adjustable_network.graph_info
    ref_graph, [ref_input], [ref_output] = reference_network.graph_info

    trainable_graph = tf.Graph()
    _copy_graph(adj_graph, trainable_graph)
    _copy_graph(ref_graph, trainable_graph, 'reference_graph')

    adj_output = _get_transformed_tensor(adj_output, trainable_graph)
    ref_output = _get_transformed_tensor(ref_output, trainable_graph,
                                         'reference_graph')
    adj_input = _get_transformed_tensor(adj_input, trainable_graph)
    ref_input = _get_transformed_tensor(ref_input, trainable_graph,
                                        'reference_graph')

    with trainable_graph.as_default():  # pylint: disable=not-context-manager

        loss = _rmse_loss(ref_output, adj_output)

        learning_rate = tf.placeholder_with_default(
            learning_strategy.initial_lr,
            shape=[],
            name='lr_for_range_scalers',
        )

        optimizer_type = learning_strategy.optimizer_type
        optimizer_type = optimizer_type.lower()
        if optimizer_type == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)
        else:
            raise NotImplementedError(
                f'optimizer "{optimizer_type}" is not supported')

        gradients = optimizer.compute_gradients(
            loss, var_list=tf.trainable_variables())
        gradients = optimizer.apply_gradients(gradients)

        train_operations = TrainOperations(gradients, learning_rate, loss)
        train_input = adj_input
        ge.reroute_ts([train_input], [ref_input])

        return trainable_graph, train_input, train_operations
    def _reroute_if_necessary(self, op: Op,
                              new_op_tensors: List[tf.Tensor]) -> bool:
        """
        Reroute old op and new op outputs if the old op's children's masks are unchanged.  If needed, insert downsample
        or upsample ops.
        :param op: Original unwinnowed op whose winnowed counterpart has output tensors new_op_tensors
        :param new_op_tensors: Output tensors of the newly created winnowed version of op
        :return: True if reroute was performed.
        """
        if len(new_op_tensors) > 1:
            # Len of new_op_tensors should only be greater than one in the case of ending at a split
            raise NotImplementedError

        current_op = op
        child_op = op.output.consumers[0]
        while child_op.type == 'branch' or OpConnectivity.get_op_connectivity(ModelApi.tensorflow, child_op.type) == \
                ConnectivityType.skip:
            # For both cases when child op is of type branch or skip connectivity, go one more level down
            current_op = child_op
            child_op = child_op.output.consumers[0]

        # Op output may have multiple consumers, but in the case of non splits, there will be only one tensor shared
        # among all consumers.  Thus looking at only the first consumer is enough to give us the correct tensor to swap.
        if child_op in self._op_to_mask_dict.keys():
            op_mask = self._op_to_mask_dict[op]
            child_op_mask = self._op_to_mask_dict[child_op]
            if not child_op_mask.are_masks_unchanged():
                return False

            # find the correct child op input mask if it has multiple input masks
            prod_index = child_op.get_input_product_index_of_parent(current_op)
            assert prod_index is not None  # did not find any input product that connects to current op
            new_op_tensor = _insert_downsample_or_upsample_ops_if_needed(
                new_op_tensors[0], op_mask.output_channel_masks[0],
                child_op_mask.input_channel_masks[prod_index])
        else:
            prod_index = child_op.get_input_product_index_of_parent(current_op)
            assert prod_index is not None  # did not find any input product that connects to current op
            new_op_tensor = new_op_tensors[0]

        # We have hit the end of a string of ops to reduce, and will now connect the newly reduced ops back to the
        # main graph.  This also detaches the old op's output from its old child op
        old_tensor = child_op.get_input_products(
        )[prod_index].tensor_dict[child_op]
        graph_editor.reroute_ts(ts0=new_op_tensor, ts1=old_tensor)
        return True
Example #10
    def _insert_weight_quantization_ops(self, ops, indices):
        if (not ops) or (len(ops) != len(indices)):
            raise ValueError('No weights to quantize!')

        self._is_train_variable = tf.Variable(initial_value=False,
                                              name='training_in_progress',
                                              dtype=tf.bool)
        self._sess.run(self._is_train_variable.initializer)

        for op, index in zip(ops, indices):

            # Modify the weight/bias inputs to use the quantized inputs
            param_in = op.inputs[index]
            self._log.debug('Quantizing input: %s for op: %s', param_in.name,
                            op.name)
            # Rename using scope to be clearer what the op is quantizing. If no scope exists, use the default name
            w_op_name = os.path.split(param_in.name)[0]
            if not w_op_name:
                w_op_name = op.name
            w_op_name = self._get_quantized_name(w_op_name)
            self._log.info("Adding weight quantization op %s", w_op_name)
            # CPU device assignment for QcQuantize op
            if not self._gpu:
                with tf.device('/cpu:0'):
                    q_op_out = _qcops.qc_quantize_deprecated(
                        name=w_op_name,
                        op_name=w_op_name,
                        training_in_progress=self._is_train_variable,
                        config=int(
                            libpytrext.config_type.CONFIG_TYPE_Q_DQ_PARAMS),
                        bitwidth=self._bw_params,
                        in_tensors=[param_in],
                        fixed_enc_mins=[],
                        fixed_enc_maxs=[],
                        quant_mode=self._quant_mode_str,
                        round_mode=self._round_mode_str,
                        num_tensors=1)
            # GPU device assignment for QcQuantize op
            else:
                q_op_out = _qcops.qc_quantize_deprecated(
                    name=w_op_name,
                    op_name=w_op_name,
                    training_in_progress=self._is_train_variable,
                    config=int(libpytrext.config_type.CONFIG_TYPE_Q_DQ_PARAMS),
                    bitwidth=self._bw_params,
                    in_tensors=[param_in],
                    fixed_enc_mins=[],
                    fixed_enc_maxs=[],
                    quant_mode=self._quant_mode_str,
                    round_mode=self._round_mode_str,
                    num_tensors=1)

            nodes_modified_count = ge.reroute_ts(tf_ops.convert_to_tensor(
                q_op_out[0][0]),
                                                 param_in,
                                                 can_modify=op)
            if nodes_modified_count != 1:
                raise ValueError('Input ' + param_in.name + ' not quantized!')
Example #11
def _FoldFusedBatchNorms(graph):
  """Finds fused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
  for match in _FindFusedBatchNorms(graph):
    scope, sep, _ = match.layer_op.name.rpartition('/')
    # Make sure new ops are added to `graph` and put on the same device as
    # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
    # named `scope`. Otherwise, TF creates a unique scope whose name starts with
    # `scope`.
    with graph.as_default(), graph.name_scope(scope + sep), ops.device(
        match.bn_op.device):
      with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep):
        # new weights = old weights * gamma / sqrt(variance + epsilon)
        # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
        multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
            match.variance_tensor + match.bn_op.get_attr('epsilon'))
        bias_tensor = math_ops.subtract(
            match.beta_tensor,
            match.mean_tensor * multiplier_tensor,
            name='bias')

        # The shape of depthwise weights is different, so we need to reshape the
        # multiplier_tensor to ensure that the scaled_weight_tensor has the
        # expected shape.
        if match.layer_op.type == 'DepthwiseConv2dNative':
          new_shape = [
              match.weight_tensor.get_shape().as_list()[2],
              match.weight_tensor.get_shape().as_list()[3]
          ]
          multiplier_tensor = array_ops.reshape(
              multiplier_tensor, new_shape, name='scale_reshape')

      # TODO(suharshs): This naming of the following ops needs to carefully
      # follow the naming expected by quantize.py. Generalize the quantize code
      # to not require these delicate naming conventions.
      scaled_weight_tensor = math_ops.multiply(
          match.weight_tensor, multiplier_tensor, name='mul_fold')

      new_layer_tensor = _CloneWithNewOperands(
          match.layer_op, match.input_tensor, scaled_weight_tensor)

      bias_add_tensor = math_ops.add(
          new_layer_tensor, bias_tensor, name='add_fold')

      nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
                                                     match.output_tensor)
      if nodes_modified_count != 1:
        raise ValueError(
            'Unexpected inputs to op: %s' % match.output_tensor.name)
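A quick numeric check of the folding formulas quoted in the comments above (made-up single-channel values):

    gamma, beta = 0.5, 0.1
    mean, variance, epsilon = 2.0, 4.0, 1e-3
    multiplier = gamma / (variance + epsilon) ** 0.5   # gamma / sqrt(variance + epsilon)
    new_weight = 3.0 * multiplier                      # old weight 3.0 folded with the BN scale
    new_bias = beta - mean * multiplier
    # for any x: x * new_weight + new_bias equals batch norm applied to x * 3.0 with these statistics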
Example #12
def FoldBatchNorms(graph):
    """Finds batch norm layers in the graph, folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
    # Fail immediately when the graph contains unsupported fused batch norm ops.
    if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'):
        raise ValueError('Fused batch norm is not supported')

    input_to_ops_map = input_to_ops.InputToOps(graph)

    for bn in common.BatchNormGroups(graph):
        has_scaling = _HasScaling(graph, input_to_ops_map, bn)

        # The mangling code intimately depends on BatchNorm node's internals.
        original_op, folded_op = _CreateFoldedOp(graph,
                                                 bn,
                                                 has_scaling=has_scaling)

        activation = common.GetEndpointActivationOp(graph, bn)
        if activation:
            nodes_modified_count = graph_editor.reroute_ts(
                [folded_op.outputs[0]], [original_op.outputs[0]],
                can_modify=[activation])
            if nodes_modified_count != 1:
                raise ValueError('Unexpected inputs to op: %s' %
                                 activation.name)
            continue

        # Treat consumer ops in bypass modules differently since they have Add
        # operations instead of Relu* above.
        add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
        add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
        nodes_modified_count = graph_editor.reroute_ts(
            [folded_op.outputs[0]], [original_op.outputs[0]],
            can_modify=[add_bypass])
        if nodes_modified_count != 1:
            raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
Example #13
def FoldBatchNorms(graph):
  """Finds batch norm layers in the graph, folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
  # Fail immediately when the graph contains unsupported fused batch norm ops.
  if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'):
    raise ValueError('Fused batch norm is not supported')

  input_to_ops_map = input_to_ops.InputToOps(graph)

  for bn in common.BatchNormGroups(graph):
    has_scaling = _HasScaling(graph, input_to_ops_map, bn)

    # The mangling code intimately depends on BatchNorm node's internals.
    original_op, folded_op = _CreateFoldedOp(graph, bn, has_scaling=has_scaling)

    activation = common.GetEndpointActivationOp(graph, bn)
    if activation:
      nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
                                                     [original_op.outputs[0]],
                                                     can_modify=[activation])
      if nodes_modified_count != 1:
        raise ValueError('Unexpected inputs to op: %s' % activation.name)
      continue

    # Treat consumer ops in bypass modules differently since they have Add
    # operations instead of Relu* above.
    add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
    add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
    nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
                                                   [original_op.outputs[0]],
                                                   can_modify=[add_bypass])
    if nodes_modified_count != 1:
      raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
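Typical usage, per the docstring (a sketch; assumes the module defining FoldBatchNorms is in scope and the graph was built without fused batch norm):

    g = tf.get_default_graph()
    FoldBatchNorms(g)   # folds each BatchNorm group into its preceding Conv2D/FC/depthwise layer in place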
Example #14
def replace_relu6_with_relu(sess: tf.compat.v1.Session,
                            relu6_op: tf.Operation):
    """
    Replaces an existing Relu6 op with a Relu.
    :param sess: active tf.compat.v1.Session
    :param relu6_op: Relu6 op to be replaced with Relu
    :return:
    """
    with sess.graph.as_default():
        assert len(relu6_op.inputs) == 1
        new_tensor = tf.nn.relu(relu6_op.inputs[0])  # pylint: disable=no-member
        relu_op = new_tensor.op

        relu_outputs = list(relu_op.outputs)
        relu6_outputs = list(relu6_op.outputs)

        # swap the two tensors using reroute
        ge.reroute_ts(ts0=relu_outputs, ts1=relu6_outputs)

        ge.detach_inputs(relu6_op)
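Hypothetical usage of the helper above (the op name is a placeholder):

    relu6_op = sess.graph.get_operation_by_name('block_1/Relu6')
    replace_relu6_with_relu(sess, relu6_op)
    # as with the other graph-surgery helpers here, save and re-import the graph to obtain a clean session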
Example #15
    def __build_pruned_evaluate_model(self, path=None):
        ''' build an evaluation model from a pruned model '''
        # early break for non-primary workers
        if not self.__is_primary_worker():
            return

        if path is None:
            path = FLAGS.save_path

        if not tf.train.checkpoint_exists(path):
            return

        with tf.Graph().as_default():
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = str(  # pylint: disable=no-member
                mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
            self.sess_eval = tf.Session(config=config)
            self.saver_eval = tf.train.import_meta_graph(path + '.meta')
            self.saver_eval.restore(self.sess_eval, path)
            eval_logits = tf.get_collection('logits')[0]
            tf.add_to_collection('logits_final', eval_logits)
            eval_images = tf.get_collection('eval_images')[0]
            tf.add_to_collection('images_final', eval_images)
            eval_labels = tf.get_collection('eval_labels')[0]
            mem_images = tf.get_collection('mem_images')[0]
            mem_labels = tf.get_collection('mem_labels')[0]

            self.sess_eval.close()

            graph_editor.reroute_ts(eval_images, mem_images)
            graph_editor.reroute_ts(eval_labels, mem_labels)

            self.sess_eval = tf.Session(config=config)
            self.saver_eval.restore(self.sess_eval, path)
            trainable_vars = self.trainable_vars
            loss, accuracy = self.calc_loss(eval_labels, eval_logits,
                                            trainable_vars)
            self.eval_op = [loss] + list(accuracy.values())
            self.sm_writer.add_graph(self.sess_eval.graph)
Example #16
    def replace_layer_with_sequential_of_two_layers(self,
                                                    layer_to_replace: Layer,
                                                    layer_a: Layer,
                                                    layer_b: Layer):
        """
        Replaces original layer with two new layers in the graph.
        Adds two new layers in the database and remove the original layer from database.

        :param layer_to_replace: layer to replace
        :param layer_a: layer a
        :param layer_b: layer b
        """

        old_bias_op = aimet_tensorflow.utils.common.get_succeeding_bias_op(
            layer_to_replace.module)
        old_outputs = [
            old_bias_op.outputs[0]
        ] if old_bias_op is not None else [layer_to_replace.module.outputs[0]]

        new_bias_op = aimet_tensorflow.utils.common.get_succeeding_bias_op(
            layer_b.module)
        new_outputs = [
            new_bias_op.outputs[0]
        ] if new_bias_op is not None else [layer_b.module.outputs[0]]

        consumers = []

        for output in old_outputs:

            for consumer in output.consumers():
                consumers.append(consumer)

        # For each tensor pair, replaces the end of ts1 (= old_outputs) with the end of ts0 (= new_outputs).
        # The ends of the tensors in ts1 (= old_outputs) are left dangling.
        _ = graph_editor.reroute_ts(ts0=new_outputs,
                                    ts1=old_outputs,
                                    can_modify=consumers)

        # Add the new layer to the database
        self._compressible_layers[id(layer_a.module)] = layer_a
        self._compressible_layers[id(layer_b.module)] = layer_b

        # Remove the layer being replaced from the database
        del self._compressible_layers[id(layer_to_replace.module)]
Example #17
 def test_compatibility(self):
     with self.assertRaises(ValueError):
         ge.reroute_ts([self.a0, self.b0], [self.a2, self.b2])
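The ValueError above comes from reroute_ts checking that each tensor pair is compatible before rewiring anything; a minimal sketch that triggers it (assumes ge is the TF 1.x contrib graph editor):

    x = tf.constant(1.0, name='x')   # float32
    y = tf.constant(1, name='y')     # int32
    ge.reroute_ts([y], [x])          # raises ValueError because the dtypes are incompatible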
def gradients(ys, xs,   # pylint: disable=too-many-statements, too-many-branches
              grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory
    Cost" by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually
                        the most expensive, so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc are
                        taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of
                        bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the
                            tensors to checkpoint
    '''

    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys],
                                       inclusive=True)

    debug_print("bwd_ops: {}".format(bwd_ops))

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: {}".format(fwd_ops))

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if op not in xs_ops]
    fwd_ops = [op for op in fwd_ops if '/assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/Assign' not in op.name]
    fwd_ops = [op for op in fwd_ops if '/read' not in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            def fixdims(t):  # tf.Dimension values are not compatible with int, convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape
            ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in forward graph that is in input to bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors {}".format(ts_filtered))

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops))
                    f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops))
                    # check that there are no shortcuts
                    b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print("Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)):  # enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception('unable to find bottleneck tensors! please provide checkpoint '
                                'nodes manually, or use checkpoints="speed".')

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops)
            sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints))
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: {}".format(
            xs_intersect_checkpoints))
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: {}, checkpoints:{}, intersect: {}".format(
        ys, checkpoints, ys_intersect_checkpoints))
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print("Warning, some output nodes are also checkpoints nodes: {}".format(
            format_ops(ys_intersect_checkpoints)))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name+"_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints, within_ops=fwd_ops)
    debug_print("Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
        len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints))
    debug_print("ops_to_copy = {}".format(ops_to_copy))
    debug_print("Processing list {}".format(ys))
    _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired {} in place of {} restricted to {}".format(
        checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops))

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs)
    debug_print("Got gradients {}".format(dv))
    debug_print("for %s", copied_ys)
    debug_print("with respect to {}".format(boundary+xs))

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {r: dr for r, dr in zip(checkpoints_disconnected.keys(),
                                            dv[:len(checkpoints_disconnected)])}
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts))
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found {} ops to copy within {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ts], checkpoints_other))
        debug_print("ops_to_copy = {}".format(ops_to_copy))
        if not ops_to_copy:  # we're done!
            break
        _, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops))
        ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other, copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other+xs,
                          grad_ys=substitute_backprops, **kwargs)
        debug_print("Got gradients {}".format(dv))
        debug_print("for {}".format(boundary))
        debug_print("with respect to {}".format(checkpoints_disconnected_other+xs))
        debug_print("with boundary backprop substitutions {}".format(substitute_backprops))

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(var_x):
            if not isinstance(var_x, tf.IndexedSlices):
                return var_x
            assert var_x.dense_shape is not None, \
                "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = var_x.indices
            while indices.shape.ndims < var_x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, var_x.values, var_x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
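A hedged usage sketch for the memory-saving gradients above (the model-building call is hypothetical; the call otherwise mirrors how tf.gradients would normally be used):

    loss = build_model(images, labels)                # hypothetical forward pass
    params = tf.trainable_variables()
    grads = gradients(loss, params, checkpoints='memory')
    train_op = tf.train.AdamOptimizer(1e-3).apply_gradients(zip(grads, params))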
Example #19
    def _InsertQuantOp(
        self,
        context,
        producer,
        consumers,
        name,
        moving_avg=True,
        init_min=-6.0,
        init_max=6.0,
        delay_requested=True,
        bits=8,
        narrow_range=False,
    ):
        """Inserts a quant op between a producer op and (multiple) consumer ops.

    Args:
      context: Context where producer and consumer operations are nested.
      producer: Producer operation of the pairs where quantization will be
        inserted.
      consumers: Consumer operations of the pairs.
      name: Name for the new quantization op within the context.
      moving_avg: Specifies whether to use exponential moving average or just
        the last value seen.
      init_min: Starting minimum value for the new quantization op.
      init_max: Starting maximum value for the new quantization op.
      delay_requested: If true, implement quantization delay where needed.
        False value explicitly disables delay quantization everywhere.
      bits: Number of bits to use for quantization, must be between 2 and 8.
      narrow_range: Whether to use the narrow quantization range
        [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    Raises:
      ValueError: When producer operation is not directly connected to the
        consumer operation.
    """
        scope = context + '/' + name
        inputs = producer.outputs[0]
        if moving_avg:
            quant = (quant_ops.MovingAvgQuantize(
                inputs,
                init_min=init_min,
                init_max=init_max,
                ema_decay=self.ema_decay,
                is_training=self.is_training,
                num_bits=bits,
                narrow_range=narrow_range,
                updates_collection=_UPDATE_QUANT_OPS,
                vars_collection=self.vars_collection,
                scope=scope))
        else:
            quant = (quant_ops.LastValueQuantize(
                inputs,
                init_min=init_min,
                init_max=init_max,
                is_training=self.is_training,
                num_bits=bits,
                narrow_range=narrow_range,
                updates_collection=_UPDATE_QUANT_OPS,
                vars_collection=self.vars_collection,
                scope=scope))

        if delay_requested and self.quant_delay and self.quant_delay > 0:
            activate_quant = math_ops.greater_equal(
                training_util.get_or_create_global_step(),
                self.quant_delay,
                name=scope + '/activate_quant')
            quant = control_flow_ops.cond(activate_quant,
                                          lambda: quant,
                                          lambda: inputs,
                                          name=scope + '/delayed_quant')

        nodes_modified_count = graph_editor.reroute_ts([quant], [inputs],
                                                       can_modify=consumers)
        if nodes_modified_count != len(consumers):
            raise ValueError(
                'Some inputs not quantized for ops: [%s]' %
                ', '.join([consumer.name for consumer in consumers]))
Example #20
    def _split_conv_layer(self, sess, svd_ranks, attr, op_name, bias_op_name=None):
        """
        Split a given conv layer given a rank
        :param sess: tf.compat.v1.Session
        :param svd_ranks: Rank to split the layer with (two ranks in case of SSVD)
        :param attr: Reference to the corresponding layer attribute
        :param op_name: Name of the op to split
        :param bias_op_name: Name of the corresponding bias op (if any)
        :return: None
        """
        # pylint: disable=too-many-statements,too-many-branches,too-many-locals

        logger.info('Splitting conv op: %s', op_name)

        # Retrieve the op(s) from the current graph
        op = sess.graph.get_operation_by_name(op_name)

        bias_op = None
        if bias_op_name:
            bias_op = sess.graph.get_operation_by_name(bias_op_name)

        # Create new 'conv_a' layer
        pad_mode = op.get_attr('padding')
        data_format = op.get_attr('data_format').decode('utf-8')
        strides = op.get_attr('strides')

        # Print current conv weight shape
        query = core.OpQuery(sess.graph)
        w_shape = query.get_weights_for_op(op).get_shape().as_list()
        logger.debug('Original %s weight shape: %s', op.name, str(w_shape))
        split_weights, weight_sizes = [], []
        split_biases, bias_sizes = [], []

        # TF weights are in [H,W,I,O] order. We must reshape the split weights to SVD format [O,I,H,W]
        # and then transpose back
        # Conv a weights are: [1, 1, w_shape[2], svd_ranks[0]]
        split_conv_a_w_shape = (svd_ranks[0], w_shape[2], 1, 1)
        conv_a_weights = np.zeros(split_conv_a_w_shape)     # transpose(2,3,1,0)
        split_weights.append(conv_a_weights.flatten().tolist())
        weight_sizes.append(conv_a_weights.size)
        if bias_op:
            conv_a_bias = np.zeros(svd_ranks[0])
            split_biases.append(conv_a_bias.flatten().tolist())
            bias_sizes.append(conv_a_bias.size)

        num_filters = w_shape[3]
        if len(svd_ranks) >= 2 and attr.mode == pymo.TYPE_SUCCESSIVE:
            # Output channels = output_rank (s)
            num_filters = svd_ranks[1]

        # Conv b weights are: [w_shape[0],w_shape[1],svd_ranks[0],num_filters]
        split_conv_b_w_shape = (num_filters, svd_ranks[0], w_shape[0], w_shape[1])
        conv_b_weights = np.zeros(split_conv_b_w_shape)
        conv_b_bias = np.zeros(num_filters)
        split_weights.append(conv_b_weights.flatten().tolist())
        weight_sizes.append(conv_b_weights.size)
        if bias_op:
            split_biases.append(conv_b_bias.flatten().tolist())
            bias_sizes.append(conv_b_bias.size)

        # Only create a third conv layer when performing successive SVD
        if len(svd_ranks) >= 2 and attr.mode == pymo.TYPE_SUCCESSIVE:
            # Conv c weights are: [1,1,num_filters,w_shape[3]]
            split_conv_c_w_shape = (w_shape[3], num_filters, 1, 1)
            conv_c_weights = np.zeros(split_conv_c_w_shape)
            conv_c_bias = np.zeros(w_shape[3])
            split_weights.append(conv_c_weights.flatten().tolist())
            weight_sizes.append(conv_c_weights.size)
            if bias_op:
                split_biases.append(conv_c_bias.flatten().tolist())
                bias_sizes.append(conv_c_bias.size)

        # Split the weights and biases according to the number of layers and ranks
        split_weights = self._svd.SplitLayerWeights(op.name, split_weights, weight_sizes, svd_ranks)
        split_biases = self._svd.SplitLayerBiases(op.name, split_biases, bias_sizes, svd_ranks)
        if split_weights:
            conv_a_name = op.name+'_a'
            conv_a_weights = np.array(split_weights[0]).reshape(split_conv_a_w_shape).transpose(2, 3, 1, 0)
            conv_a_w = tf.Variable(initial_value=conv_a_weights, name=conv_a_name+'_w', dtype=tf.float32)
            logger.debug('%s weight shape: %s', conv_a_name, str(conv_a_weights.shape))

            # Create conv_a using default strides (1,1)
            # pylint: disable=no-member
            conv_acts = tf.nn.conv2d(op.inputs[0], conv_a_w, strides=[1, 1, 1, 1], data_format=data_format,
                                     padding=pad_mode, name=op.name+'_a')  # dilation_rate=dilation_rate
            if bias_op:
                conv_a_bias = tf.Variable(initial_value=split_biases[0], name=conv_a_name+'_bias', dtype=tf.float32)
                conv_acts = conv_acts + conv_a_bias     # tf.nn.bias_add(conv_acts, split_biases[0])

        if len(split_weights) > 1:
            # Create conv_b
            conv_b_name = op.name+'_b'
            conv_b_weights = np.array(split_weights[1]).reshape(split_conv_b_w_shape).transpose(2, 3, 1, 0)
            conv_b_w = tf.Variable(initial_value=conv_b_weights, name=conv_b_name+'_w', dtype=tf.float32)
            logger.debug('%s weight shape: %s', conv_b_name, str(conv_b_weights.shape))

            # pylint: disable=no-member
            conv_acts = tf.nn.conv2d(conv_acts, conv_b_w, strides=strides, data_format=data_format, padding=pad_mode, name=conv_b_name) #dilation_rate=dilation_rate
            if bias_op:
                conv_b_bias = tf.Variable(initial_value=split_biases[1], name=conv_b_name+'_bias', dtype=tf.float32)
                conv_acts = conv_acts + conv_b_bias     # tf.nn.bias_add(conv_acts, split_biases[1])
        ratio = self._compute_per_layer_compression_ratio([conv_a_w.shape, conv_b_w.shape], conv_acts.shape, w_shape, "Conv2D")
        # Only create a third conv layer when performing successive SVD
        if len(split_weights) > 2 and len(svd_ranks) >= 2 and attr.mode == pymo.TYPE_SUCCESSIVE:
            # Create conv_c, using default strides (1,1)
            conv_c_name = op.name+'_c'
            conv_c_weights = np.array(split_weights[2]).reshape(split_conv_c_w_shape).transpose(2, 3, 1, 0)
            conv_c_w = tf.Variable(initial_value=conv_c_weights, name=conv_c_name+'_w', dtype=tf.float32)
            logger.debug('%s weight shape: %s', conv_c_name, str(conv_c_weights.shape))

            # pylint: disable=no-member
            conv_acts = tf.nn.conv2d(conv_acts, conv_c_w, strides=[1, 1, 1, 1], data_format=data_format,
                                     padding=pad_mode, name=conv_c_name)
            if bias_op:
                conv_c_bias = tf.Variable(initial_value=split_biases[2], name=conv_c_name+'_bias', dtype=tf.float32)
                conv_acts = conv_acts + conv_c_bias     # tf.nn.bias_add(conv_acts, split_biases[2])

        consumers = []
        rerouted_inputs = [bias_op.outputs[0]] if bias_op else [op.outputs[0]]
        for inp in rerouted_inputs:
            for consumer in inp.consumers():
                consumers.append(consumer)
        _ = ge.reroute_ts(conv_acts, rerouted_inputs, can_modify=consumers)

        return ratio
Example #21
def replace_input(op, old_input, new_input):
  """Replaces old input with new input in op"""
  ge.reroute_ts([new_input], [old_input], can_modify=[op])
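For example, to make a single op read from a placeholder instead of its original input tensor (names are hypothetical):

    old_in = some_op.inputs[0]
    new_in = tf.placeholder(old_in.dtype, shape=old_in.get_shape(), name='new_input')
    replace_input(some_op, old_in, new_in)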
Example #22
def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
    """Finds fused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, true if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
    for match in _FindFusedBatchNorms(graph):
        scope, sep, _ = match.layer_op.name.rpartition('/')
        # Make sure new ops are added to `graph` and put on the same device as
        # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
        # named `scope`. Otherwise, TF creates a unique scope whose name starts with
        # `scope`.
        with graph.as_default(), graph.name_scope(scope + sep):
            with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep):
                # new weights = old weights * gamma / sqrt(variance + epsilon)
                # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
                multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
                    match.variance_tensor + match.bn_op.get_attr('epsilon'))
                bias_tensor = math_ops.subtract(match.beta_tensor,
                                                match.mean_tensor *
                                                multiplier_tensor,
                                                name='bias')

                correction_scale, correction_recip, correction_offset = None, None, None
                if is_training:
                    correction_scale, correction_recip, correction_offset = (
                        _ComputeBatchNormCorrections(
                            context='',
                            match=match,
                            freeze_batch_norm_delay=freeze_batch_norm_delay,
                            fused_batch_norm=True))
                # The shape of depthwise weights is different, so we need to reshape the
                # multiplier_tensor to ensure that the scaled_weight_tensor has the
                # expected shape.
                weights = match.weight_tensor
                if match.layer_op.type == 'DepthwiseConv2dNative':
                    new_shape = [
                        match.weight_tensor.get_shape().as_list()[2],
                        match.weight_tensor.get_shape().as_list()[3]
                    ]
                    multiplier_tensor = array_ops.reshape(multiplier_tensor,
                                                          new_shape,
                                                          name='scale_reshape')

                    if correction_scale is not None:
                        correction_scale = array_ops.reshape(
                            correction_scale,
                            new_shape,
                            name='correction_reshape')

            if correction_scale is not None:
                weights = math_ops.multiply(correction_scale,
                                            weights,
                                            name='correction_mult')

            scaled_weight_tensor = math_ops.multiply(weights,
                                                     multiplier_tensor,
                                                     name='mul_fold')
            new_layer_tensor = _CloneWithNewOperands(match.layer_op,
                                                     match.input_tensor,
                                                     scaled_weight_tensor,
                                                     match.batch_to_space_op)

            if correction_recip is not None:
                new_layer_tensor = math_ops.multiply(correction_recip,
                                                     new_layer_tensor,
                                                     name='post_conv_mul')
                new_layer_tensor = math_ops.add(new_layer_tensor,
                                                (correction_offset),
                                                'correction_add')

            bias_add_tensor = math_ops.add(new_layer_tensor,
                                           bias_tensor,
                                           name='add_fold')

            nodes_modified_count = graph_editor.reroute_ts(
                bias_add_tensor, match.output_tensor)
            if nodes_modified_count == 0:
                raise ValueError(
                    'Folding batch norms failed, %s had no outputs.' %
                    match.output_tensor.name)
Example #23
 def test_compatibility(self):
   with self.assertRaises(ValueError):
     ge.reroute_ts([self.a0, self.b0], [self.a2, self.b2])
Example #24
 def replace_varible(self, new_list):
     assert len(new_list) == len(self._var)
     for new_var, modify_ops in zip(new_list, self._modify_consumer):
         for op, v_replaced in modify_ops:
             ge.reroute_ts(new_var, v_replaced, can_modify=[op])
Beispiel #25
0
  def __build_pruned_train_model(self, path=None, finetune=False): # pylint: disable=too-many-locals
    ''' build a training model from pruned model '''
    if path is None:
      path = FLAGS.save_path

    with tf.Graph().as_default():
      config = tf.ConfigProto()
      config.gpu_options.visible_device_list = str(# pylint: disable=no-member
        mgw.local_rank() if FLAGS.enbl_multi_gpu else 0)
      self.sess_train = tf.Session(config=config)
      self.saver_train = tf.train.import_meta_graph(path + '.meta')
      self.saver_train.restore(self.sess_train, path)
      logits = tf.get_collection('logits')[0]
      train_images = tf.get_collection('train_images')[0]
      train_labels = tf.get_collection('train_labels')[0]
      mem_images = tf.get_collection('mem_images')[0]
      mem_labels = tf.get_collection('mem_labels')[0]

      self.sess_train.close()

      graph_editor.reroute_ts(train_images, mem_images)
      graph_editor.reroute_ts(train_labels, mem_labels)

      self.sess_train = tf.Session(config=config)
      self.saver_train.restore(self.sess_train, path)

      trainable_vars = self.trainable_vars
      loss, accuracy = self.calc_loss(train_labels, logits, trainable_vars)
      self.accuracy_keys = list(accuracy.keys())

      if FLAGS.enbl_dst:
        logits_dst = self.learner_dst.calc_logits(self.sess_train, train_images)
        loss += self.learner_dst.calc_loss(logits, logits_dst)

      tf.summary.scalar('loss', loss)
      for key in accuracy.keys():
        tf.summary.scalar(key, accuracy[key])
      self.summary_op = tf.summary.merge_all()

      global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32, trainable=False)
      self.global_step = global_step
      lrn_rate, self.nb_iters_train = setup_lrn_rate(
        self.global_step, self.model_name, self.dataset_name)

      if finetune and not FLAGS.cp_retrain:
        mom_optimizer = tf.train.AdamOptimizer(FLAGS.cp_lrn_rate_ft)
        self.log_op = [tf.constant(FLAGS.cp_lrn_rate_ft), loss, list(accuracy.values())]
      else:
        mom_optimizer = tf.train.MomentumOptimizer(lrn_rate, FLAGS.momentum)
        self.log_op = [lrn_rate, loss, list(accuracy.values())]

      if FLAGS.enbl_multi_gpu:
        optimizer = mgw.DistributedOptimizer(mom_optimizer)
      else:
        optimizer = mom_optimizer
      grads_origin = optimizer.compute_gradients(loss, trainable_vars)
      grads_pruned, masks = self.__calc_grads_pruned(grads_origin)


      with tf.control_dependencies(self.update_ops):
        self.train_op = optimizer.apply_gradients(grads_pruned, global_step=global_step)

      self.sm_writer.add_graph(tf.get_default_graph())
      self.train_init_op = \
        tf.initialize_variables(mom_optimizer.variables() + [global_step] + masks)

      if FLAGS.enbl_multi_gpu:
        self.bcast_op = mgw.broadcast_global_variables(0)
Beispiel #26
0
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  if producer_scope and not producer.name.startswith(producer_scope):
    logging.info(
        '_InsertQuantOp ignores context="%s" name="%s" '
        'because producer "%s" is not in scope "%s"',
        context, name, producer.name, producer_scope)
    return

  if consumer_scope:
    consumers_in_scope = []
    for consumer in consumers:
      if consumer.name.startswith(consumer_scope):
        consumers_in_scope.append(consumer)
      else:
        logging.info(
            '_InsertQuantOp context="%s" name="%s" ignores '
            'consumer "%s" because it is not in scope "%s"',
            context, name, consumer.name, consumer_scope)
        return
    consumers = consumers_in_scope

  name_prefix = _AddContextToName(context, name)
  # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
  # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
  # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
  # breaks things later.
  name_scope = ops.get_name_scope()
  if name_scope:
    name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

  inputs = producer.outputs[0]
  # Prevent ops from being quantized multiple times. Bypass ops can sometimes
  # overlap between multiple matches, so we need to ensure that we don't
  # add duplicate FakeQuant operations.
  fake_quant_ops = set([
      'FakeQuantWithMinMaxVars',
      'FakeQuantWithMinMaxArgs'
  ])
  if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
    return

  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  if consumers:
    tensors_modified_count = graph_editor.reroute_ts(
        [quant], [inputs], can_modify=consumers)
    # Some operations can have multiple output tensors going to the same
    # consumer. Since consumers is a set, we need to ensure that
    # tensors_modified_count is greater than or equal to the length of the set
    # of consumers.
    if tensors_modified_count < len(consumers):
      raise ValueError('No inputs quantized for ops: [%s]' % ', '.join(
          [consumer.name for consumer in consumers]))
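All the _InsertQuantOp variants in this collection follow the same pattern: build a FakeQuant op (moving-average or last-value) on producer.outputs[0], optionally gate it behind a quant_delay condition, and splice it in front of the listed consumers with reroute_ts. Below is a minimal sketch of just the splice step, assuming TF 1.x and using tf.quantization.fake_quant_with_min_max_args as a stand-in for the contrib quant ops; the helper name insert_fake_quant is illustrative, not part of the original code.

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge  # TF 1.x only

def insert_fake_quant(producer, consumers, name):
    """Splice a stand-in FakeQuant op between producer and consumers (sketch)."""
    inputs = producer.outputs[0]
    quant = tf.quantization.fake_quant_with_min_max_args(
        inputs, min=-6.0, max=6.0, num_bits=8, name=name)
    modified = ge.reroute_ts([quant], [inputs], can_modify=consumers)
    if modified < len(consumers):
        raise ValueError('No inputs quantized for ops: [%s]' %
                         ', '.join(c.name for c in consumers))
    return quant

# The consumer list must be collected before the FakeQuant op is created,
# otherwise the new op could be rewired onto its own input, e.g.:
# insert_fake_quant(conv_op, conv_op.outputs[0].consumers(), 'act_quant')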
Beispiel #27
0
    X_train = X_train_
    Y_train = Y_train_

    # Reroute tensors to the location of the data on the GPU, add noise and image augmentation (random crops)
    train_idx = tf.placeholder(tf.int32, shape=[None], name='train_idx')
    tf.add_to_collection('placeholders', train_idx)
    input_noise_magnitude = tf.placeholder_with_default(
        0.0, shape=[], name='input_noise_magnitude')
    tf.add_to_collection('placeholders', input_noise_magnitude)
    X_train = tf.gather(X_train, train_idx)
    X_train = tf.pad(X_train, paddings=[[0, 0], [3, 3], [3, 3], [0, 0]])
    X_train = u.random_crop(X_train, [28, 28, 1])
    X_train += input_noise_magnitude * tf.random_normal(tf.shape(X_train),
                                                        dtype=tf.float32)
    Y_train = tf.gather(Y_train, train_idx)
    ge.reroute_ts([X_train, Y_train], [X, labels])

    # Define tensor to compute accuracy and confidence metrics
    Y_pred = tf.argmax(Y, axis=-1, output_type=tf.int32)
    prob = tf.nn.softmax(Y, dim=-1)
    Y_pred_prob = tf.reduce_max(prob, axis=-1)
    A = tf.cast(tf.equal(Y_train, Y_pred), tf.float32)
    acc = tf.reduce_mean(A)
    conf = tf.reduce_mean(Y_pred_prob * A)

    # Start the TF session and load variables
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    with tf.Session(config=config) as sess:
        if RESET_PARAMETERS:
Beispiel #28
0
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
    name_prefix = _AddContextToName(context, name)
    # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
    # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
    # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
    # breaks things later.
    name_prefix = common.DropStringPrefix(name_prefix,
                                          ops.get_name_scope() + '/')

    inputs = producer.outputs[0]
    if moving_avg:
        quant = (quant_ops.MovingAvgQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             ema_decay=ema_decay,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))
    else:
        quant = (quant_ops.LastValueQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))

    if quant_delay and quant_delay > 0:
        activate_quant = math_ops.greater_equal(
            common.CreateOrGetQuantizationStep(),
            quant_delay,
            name=name_prefix + '/activate_quant')
        quant = control_flow_ops.cond(activate_quant,
                                      lambda: quant,
                                      lambda: inputs,
                                      name=name_prefix + '/delayed_quant')

    nodes_modified_count = graph_editor.reroute_ts([quant], [inputs],
                                                   can_modify=consumers)
    if nodes_modified_count != len(consumers):
        raise ValueError('Some inputs not quantized for ops: [%s]' %
                         ', '.join([consumer.name for consumer in consumers]))
Beispiel #29
0
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.GLOBAL_VARIABLES,
                   narrow_range=False,
                   producer_scope=None,
                   consumer_scope=None):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    producer_scope: The restriction of producer scope. If not None, the new op
      will be inserted only when the producer is in this scope.
    consumer_scope: The restriction of consumer scope. If not None, the new op
      will be inserted only when all the consumers are in this scope.
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
    if producer_scope and not producer.name.startswith(producer_scope):
        logging.info(
            '_InsertQuantOp ignores context="%s" name="%s" '
            'because producer "%s" is not in scope "%s"', context, name,
            producer.name, producer_scope)
        return

    if consumer_scope:
        consumers_in_scope = []
        for consumer in consumers:
            if consumer.name.startswith(consumer_scope):
                consumers_in_scope.append(consumer)
            else:
                logging.info(
                    '_InsertQuantOp context="%s" name="%s" ignores '
                    'consumer "%s" because it is not in scope "%s"', context,
                    name, consumer.name, consumer_scope)
                return
        consumers = consumers_in_scope

    name_prefix = _AddContextToName(context, name)
    # This is needed on TPU where name_scope == 'TPUReplicate/loop', and
    # name_prefix starts with 'TPUReplicate/loop/'; without dropping it
    # variables are created as TPUReplicate/loop/TPUReplicate/loop/..., which
    # breaks things later.
    name_scope = ops.get_name_scope()
    if name_scope:
        name_prefix = common.DropStringPrefix(name_prefix, name_scope + '/')

    inputs = producer.outputs[0]
    # Prevent ops from being quantized multiple times. Bypass ops can sometimes
    # overlap between multiple matches, so we need to ensure that we don't
    # add duplicate FakeQuant operations.
    fake_quant_ops = set(
        ['FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs'])
    if fake_quant_ops.intersection(set([c.type for c in inputs.consumers()])):
        return

    if moving_avg:
        quant = (quant_ops.MovingAvgQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             ema_decay=ema_decay,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))
    else:
        quant = (quant_ops.LastValueQuantize(inputs,
                                             init_min=init_min,
                                             init_max=init_max,
                                             is_training=is_training,
                                             num_bits=bits,
                                             narrow_range=narrow_range,
                                             vars_collection=vars_collection,
                                             name_prefix=name_prefix))

    if quant_delay and quant_delay > 0:
        activate_quant = math_ops.greater_equal(
            common.CreateOrGetQuantizationStep(),
            quant_delay,
            name=name_prefix + '/activate_quant')
        quant = control_flow_ops.cond(activate_quant,
                                      lambda: quant,
                                      lambda: inputs,
                                      name=name_prefix + '/delayed_quant')

    if consumers:
        tensors_modified_count = graph_editor.reroute_ts([quant], [inputs],
                                                         can_modify=consumers)
        # Some operations can have multiple output tensors going to the same
        # consumer. Since consumers is a set, we need to ensure that
        # tensors_modified_count is greater than or equal to the length of the set
        # of consumers.
        if tensors_modified_count < len(consumers):
            raise ValueError(
                'No inputs quantized for ops: [%s]' %
                ', '.join([consumer.name for consumer in consumers]))
Beispiel #30
0
def _InsertQuantOp(context,
                   name,
                   producer,
                   consumers,
                   is_training,
                   moving_avg=True,
                   init_min=-6.0,
                   init_max=6.0,
                   bits=8,
                   ema_decay=0.999,
                   quant_delay=None,
                   vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                   narrow_range=False):
  """Inserts a quant op between a producer op and (multiple) consumer ops.

  Args:
    context: Context where producer and consumer operations are nested.
    name: Name for the new quantization op within the context.
    producer: Producer operation of the pairs where quantization will be
      inserted.
    consumers: Consumer operations of the pairs.
    is_training: Whether quantizing training graph or eval graph.
    moving_avg: Specifies whether to use exponential moving average or just
      the last value seen.
    init_min: Starting minimum value for the new quantization op.
    init_max: Starting maximum value for the new quantization op.
    bits: Number of bits to use for quantization, must be between 2 and 8.
    ema_decay: (Optional) Float, EMA decay parameter.  EMA is used to update
      quantization intervals for quantizing activations (see here about EMA:
      https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average).
    quant_delay: (Optional, default None) Int, count of global steps for which
      to delay quantization.  This helps weights stabilize at the start of
      training.
    vars_collection: (Optional) Collection where to store the variables for
      quantization interval ends.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^bits - 1] or wide range [0; 2^bits - 1].
  Raises:
    ValueError: When producer operation is not directly connected to the
      consumer operation.
  """
  name_prefix = _AddContextToName(context, name)
  inputs = producer.outputs[0]
  if moving_avg:
    quant = (
        quant_ops.MovingAvgQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            ema_decay=ema_decay,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))
  else:
    quant = (
        quant_ops.LastValueQuantize(
            inputs,
            init_min=init_min,
            init_max=init_max,
            is_training=is_training,
            num_bits=bits,
            narrow_range=narrow_range,
            vars_collection=vars_collection,
            name_prefix=name_prefix))

  if quant_delay and quant_delay > 0:
    activate_quant = math_ops.greater_equal(
        common.CreateOrGetQuantizationStep(),
        quant_delay,
        name=name_prefix + '/activate_quant')
    quant = control_flow_ops.cond(
        activate_quant,
        lambda: quant,
        lambda: inputs,
        name=name_prefix + '/delayed_quant')

  nodes_modified_count = graph_editor.reroute_ts(
      [quant], [inputs], can_modify=consumers)
  if nodes_modified_count != len(consumers):
    raise ValueError('Some inputs not quantized for ops: [%s]' % ', '.join(
        [consumer.name for consumer in consumers]))
Beispiel #31
0
    def _insert_activation_quantization_ops(self, ops, collect_stats=True):

        # pylint: disable=too-many-locals
        encodings = self._activation_encodings['encodings']
        if encodings:
            self._log.info("Using fixed activation encodings")

        # Add all the activation quantization operations
        for op in ops:
            op_name = self._get_quantized_name(op.name)
            self._log.info("Adding quantization activation %s for %s", op_name,
                           op.name)

            # When fixed encodings are set we aren't collecting stats, so use the collected encodings
            enc_mins, enc_maxs = [], []
            config = int(libpytrext.config_type.CONFIG_TYPE_UPDATE_STATS)
            if not collect_stats:
                if not encodings:
                    raise RuntimeError(
                        'No activation encodings recorded for activation quantization ops!'
                    )
                if op_name not in encodings:
                    raise RuntimeError("Can't find activation encoding for: " +
                                       op_name)
                config = int(
                    libpytrext.config_type.CONFIG_TYPE_Q_DQ_ACTIVATIONS)
                encoding_tuple = encodings[op_name]
                enc_mins = encoding_tuple[0]
                enc_maxs = encoding_tuple[1]
                self._log.debug("Using min,max encodings: %s,%s", enc_mins,
                                enc_maxs)

            # Add the new activation quantization op and reroute the outputs from the producer node to the
            # quantization op and the quantization outputs to the consumer(s)
            inputs = [output for output in op.outputs]
            num_tensors = len(inputs)
            consumers = []
            for inp in inputs:
                for consumer in inp.consumers():
                    if 'gradients' not in consumer.name:
                        consumers.append(consumer)

            # CPU device assignment for QcQuantize op
            if not self._gpu:
                with tf.device('/cpu:0'):
                    q_op_out = _qcops.qc_quantize_deprecated(
                        name=op_name,
                        op_name=op_name,
                        training_in_progress=self._is_train_variable,
                        config=config,
                        bitwidth=self._bw_acts,
                        in_tensors=inputs,
                        fixed_enc_mins=enc_mins,
                        fixed_enc_maxs=enc_maxs,
                        quant_mode=self._quant_mode_str,
                        round_mode=self._round_mode_str,
                        num_tensors=num_tensors)

            # GPU device assignment for QcQuantize op
            else:
                q_op_out = _qcops.qc_quantize_deprecated(
                    name=op_name,
                    op_name=op_name,
                    training_in_progress=self._is_train_variable,
                    config=config,
                    bitwidth=self._bw_acts,
                    in_tensors=inputs,
                    fixed_enc_mins=enc_mins,
                    fixed_enc_maxs=enc_maxs,
                    quant_mode=self._quant_mode_str,
                    round_mode=self._round_mode_str,
                    num_tensors=num_tensors)
            qc_outputs = [
                tf_ops.convert_to_tensor(q_op_out[i][0])
                for i in range(len(inputs))
            ]
            num_rerouted_outputs = ge.reroute_ts(qc_outputs,
                                                 inputs,
                                                 can_modify=consumers)
            if num_rerouted_outputs != len(consumers):
                raise ValueError('Failed to map ' + str(len(consumers)) +
                                 ' quantization output(s). Only mapped ' +
                                 str(num_rerouted_outputs))
            # Save the activation quantization op name for later
            if collect_stats:
                self._quant_act_ops.append(op_name)
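The pass above attaches a QcQuantize op to every output of op and reroutes those outputs to all consumers that are not gradient ops, then checks the reroute count. The wiring can be sketched generically as below; fake_quant_with_min_max_args only stands in for the proprietary qc_quantize_deprecated kernel, and the helper name quantize_op_outputs is illustrative (TF 1.x assumed).

import tensorflow as tf
from tensorflow.contrib import graph_editor as ge  # TF 1.x only

def quantize_op_outputs(op, bitwidth=8):
    """Insert a stand-in quantize op after each output of op (sketch)."""
    outputs = list(op.outputs)
    # Only forward-pass consumers are rewired; gradient ops keep the originals.
    consumers = [c for t in outputs for c in t.consumers()
                 if 'gradients' not in c.name]
    quantized = [tf.quantization.fake_quant_with_min_max_args(
                     t, min=-6.0, max=6.0, num_bits=bitwidth,
                     name='%s_quantized_%d' % (op.name, i))
                 for i, t in enumerate(outputs)]
    rerouted = ge.reroute_ts(quantized, outputs, can_modify=consumers)
    if rerouted != len(consumers):
        raise ValueError('Failed to map %d quantization output(s); only mapped %d'
                         % (len(consumers), rerouted))
    return quantized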
Beispiel #32
0
  def _InsertQuantOp(
      self,
      context,
      producer,
      consumers,
      name,
      moving_avg=True,
      init_min=-6.0,
      init_max=6.0,
      delay_requested=True,
      bits=8,
      narrow_range=False,):
    """Inserts a quant op between a producer op and (multiple) consumer ops.

    Args:
      context: Context where producer and consumer operations are nested.
      producer: Producer operation of the pairs where quantization will be
        inserted.
      consumers: Consumer operations of the pairs.
      name: Name for the new quantization op within the context.
      moving_avg: Specifies whether to use exponential moving average or just
        the last value seen.
      init_min: Starting minimum value for the new quantization op.
      init_max: Starting maximum value for the new quantization op.
      delay_requested: If true, implement quantization delay where needed.
        False value explicitly disables delay quantization everywhere.
      bits: Number of bits to use for quantization, must be between 2 and 8.
      narrow_range: Whether to use the narrow quantization range
        [1; 2^bits - 1] or wide range [0; 2^bits - 1].
    Raises:
      ValueError: When producer operation is not directly connected to the
        consumer operation.
    """
    scope = context + '/' + name
    inputs = producer.outputs[0]
    if moving_avg:
      quant = (quant_ops.MovingAvgQuantize(
          inputs,
          init_min=init_min,
          init_max=init_max,
          ema_decay=self.ema_decay,
          is_training=self.is_training,
          num_bits=bits,
          narrow_range=narrow_range,
          updates_collection=_UPDATE_QUANT_OPS,
          vars_collection=self.vars_collection,
          scope=scope))
    else:
      quant = (quant_ops.LastValueQuantize(
          inputs,
          init_min=init_min,
          init_max=init_max,
          is_training=self.is_training,
          num_bits=bits,
          narrow_range=narrow_range,
          updates_collection=_UPDATE_QUANT_OPS,
          vars_collection=self.vars_collection,
          scope=scope))

    if delay_requested and self.quant_delay and self.quant_delay > 0:
      activate_quant = math_ops.greater_equal(
          training_util.get_or_create_global_step(),
          self.quant_delay,
          name=scope + '/activate_quant')
      quant = control_flow_ops.cond(
          activate_quant,
          lambda: quant,
          lambda: inputs,
          name=scope + '/delayed_quant')

    nodes_modified_count = graph_editor.reroute_ts(
        [quant], [inputs], can_modify=consumers)
    if nodes_modified_count != len(consumers):
      raise ValueError('Some inputs not quantized for ops: [%s]' %
                       ', '.join([consumer.name for consumer in consumers]))
Beispiel #33
0
    def _split_fc_layer(self, sess, svd_ranks, op_name, bias_op_name=None):
        """
        Split a given fully connected layer with the given rank
        :param sess: tf.compat.v1.Session
        :param svd_ranks: Rank to split the layer with (two ranks in case of SSVD)
        :param op_name: Name of the op to split
        :param bias_op_name: Name of the corresponding bias op (if any)
        :return: Per-layer compression ratio
        """
        # pylint: disable=too-many-statements, too-many-locals

        logger.info('Splitting fully connected op: %s', op_name)

        # Retrieve the op(s) from the current graph
        op = sess.graph.get_operation_by_name(op_name)
        bias_op = None
        if bias_op_name:
            bias_op = sess.graph.get_operation_by_name(bias_op_name)

        # Log the original FC weight shape
        query = core.OpQuery(sess.graph)
        w_shape = query.get_weights_for_op(op).get_shape().as_list()
        logger.debug('Original %s weight shape: %s', op.name, str(w_shape))
        split_weights, weight_sizes = [], []
        split_biases, bias_sizes = [], []

        # FC a weights are [w_shape[0], svd_ranks[0]] in [I, O] order.
        # The split weights come back in SVD format [O, I] and are transposed back to [I, O] below.
        split_fc_a_w_shape = (svd_ranks[0], w_shape[0])
        fc_a_weights = np.zeros(split_fc_a_w_shape)
        fc_a_bias = np.zeros(svd_ranks[0])
        split_weights.append(fc_a_weights.flatten().tolist())
        weight_sizes.append(fc_a_weights.size)
        if bias_op:
            split_biases.append(fc_a_bias.flatten().tolist())
            bias_sizes.append(fc_a_bias.size)

        # FC b weights are [svd_ranks[0], w_shape[1]] in [I, O] order.
        # The split weights come back in SVD format [O, I] and are transposed back to [I, O] below.
        split_fc_b_w_shape = (w_shape[1], svd_ranks[0])
        fc_b_weights = np.zeros(split_fc_b_w_shape)
        split_weights.append(fc_b_weights.flatten().tolist())
        weight_sizes.append(fc_b_weights.size)
        if bias_op:
            fc_b_bias = np.zeros(w_shape[1])
            split_biases.append(fc_b_bias.flatten().tolist())
            bias_sizes.append(fc_b_bias.size)

        # Split the weights and biases according to the number of layers and ranks
        split_weights = self._svd.SplitLayerWeights(op.name, split_weights, weight_sizes, svd_ranks)
        split_biases = self._svd.SplitLayerBiases(op.name, split_biases, bias_sizes, svd_ranks)

        if split_weights:
            fc_a_name = op.name+'_a'
            fc_a_weights = np.array(split_weights[0]).reshape(split_fc_a_w_shape).transpose(1, 0)
            fc_a_w = tf.Variable(initial_value=fc_a_weights, name=fc_a_name+'_w', dtype=tf.float32)
            logger.debug('%s weight shape: %s', fc_a_name, str(fc_a_weights.shape))

            # Create fc_a using default strides (1,1)
            fc_acts = tf.matmul(op.inputs[0], fc_a_w, name=fc_a_name)
            if bias_op:
                fc_a_bias = tf.Variable(initial_value=split_biases[0], name=fc_a_name+'_bias', dtype=tf.float32)
                fc_acts = fc_acts + fc_a_bias

        if len(split_weights) > 1:
            # Create fc_b
            fc_b_name = op.name+'_b'
            fc_b_weights = np.array(split_weights[1]).reshape(split_fc_b_w_shape).transpose(1, 0)
            fc_b_w = tf.Variable(initial_value=fc_b_weights, name=fc_b_name+'_w', dtype=tf.float32)
            logger.debug('%s weight shape: %s', fc_b_name, str(fc_b_weights.shape))
            fc_acts = tf.matmul(fc_acts, fc_b_w, name=fc_b_name)
            if bias_op:
                fc_b_bias = tf.Variable(initial_value=split_biases[1], name=fc_b_name+'_bias', dtype=tf.float32)
                fc_acts = fc_acts + fc_b_bias
        ratio = self._compute_per_layer_compression_ratio([fc_a_w.shape, fc_b_w.shape], fc_acts.shape, w_shape, 'MatMul')
        consumers = []
        rerouted_inputs = [bias_op.outputs[0]] if bias_op else [op.outputs[0]]
        for inp in rerouted_inputs:
            for consumer in inp.consumers():
                consumers.append(consumer)
        _ = ge.reroute_ts(fc_acts, rerouted_inputs, can_modify=consumers)
        return ratio
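_split_fc_layer replaces one MatMul with weight shape [I, O] by two MatMuls of shapes [I, r] and [r, O], with the factor values supplied by the SVD library via self._svd.SplitLayerWeights. The rank-r factorization it relies on can be sketched in numpy as below; the helper name and the way the singular values are absorbed are illustrative and may differ from what the library actually does.

import numpy as np

def split_fc_weights(w, rank):
    """Factor an [I, O] FC weight matrix into [I, r] @ [r, O] via truncated SVD."""
    u, s, vt = np.linalg.svd(w, full_matrices=False)
    w_a = u[:, :rank] * s[:rank]     # [I, r], singular values absorbed into the first factor
    w_b = vt[:rank, :]               # [r, O]
    return w_a, w_b

w = np.random.randn(256, 64).astype(np.float32)
w_a, w_b = split_fc_weights(w, rank=32)
# The product approximates the original weights; at full rank it is exact.
print(np.linalg.norm(w - w_a @ w_b) / np.linalg.norm(w))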
Beispiel #34
0
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
    by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
                        so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
    '''

    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if not op in xs_ops]
    fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    checkpoints = 'collection'
    # construct list of tensors to checkpoint during forward pass, if not
    # given as input

    stereo_checkpoints = ge.filter_ts_from_regex(fwd_ops, "add")
    motion_checkpoints = ge.filter_ts_from_regex(fwd_ops, "Conv2D")

    my_ckps = []
    for x in motion_checkpoints:
        if ("motion" in x.name) and ("BatchNorm" not in x.name):
            my_ckps.append(x)
    for x in stereo_checkpoints:
        if ("stereo" in x.name) and ("BatchNorm" not in x.name):
            my_ckps.append(x)

    checkpoints = my_ckps
    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: %s",
            format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = d_xs_new[j]
                else:
                    d_xs[j] += d_xs_new[j]

    return d_xs
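The memory-saving gradients above is called like tf.gradients: it returns the derivatives of ys with respect to xs, recomputing non-checkpointed activations during the backward pass instead of storing them. Note that this particular variant overrides the caller's checkpoints argument with its own regex-based selection ('add' tensors in 'stereo' scopes and 'Conv2D' tensors in 'motion' scopes); the sketch below shows the calling convention described in the docstring, as honored by the upstream implementation with checkpoints='collection' (TF 1.x, layer names illustrative).

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])
labels = tf.placeholder(tf.float32, [None, 10])
h = tf.layers.dense(x, 512, activation=tf.nn.relu)
logits = tf.layers.dense(h, 10)
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))

# Tensors added to the 'checkpoints' collection are kept; everything between
# them is recomputed on the backward pass instead of being stored.
tf.add_to_collection('checkpoints', h)

xs = tf.trainable_variables()
grads = gradients(loss, xs, checkpoints='collection')
train_op = tf.train.AdamOptimizer(1e-3).apply_gradients(zip(grads, xs))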
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
                                 fused_batch_norm):
  """Computes batch norm correction params.

     Before batch normalization is frozen:
     We use batch statistics for batch norm.
       correction_scale = sigma_b/sigma_mv
       correction_recip = 1/correction_scale
       correction_offset = 0

     After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

     Batch norm is frozen if global_step > bn_freeze_delay.
     The corrections ensure that:
     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather than
     jumping across mini-batches.
     b) Changing the values of the corrections allows one to switch between
     using batch statistics and using the moving mean and variance, without
     requiring changes to batch_norm.


  Args:
    context: The scope under which we look for batch norm params
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.
    fused_batch_norm: Bool, true if fused batch norm is used.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset
  """

  g = ops.get_default_graph()
  prefix = '' if not context else context + '/'
  with g.name_scope(prefix + 'batch_norm_correction'):
    recip_sigma_mv = math_ops.rsqrt(
        match.moving_variance_tensor + match.batch_epsilon)
    recip_sigma = math_ops.rsqrt(match.variance_tensor + match.batch_epsilon)
    correction_scale = math_ops.divide(
        recip_sigma_mv, recip_sigma, name='scale_compute')
    correction_scale = array_ops.identity(
        correction_scale, name='correction_scale')
    correction_recip = math_ops.reciprocal(
        correction_scale, name='reciprocal_compute')
    correction_offset = math_ops.multiply(
        match.gamma_tensor,
        match.mean_tensor * recip_sigma -
        match.moving_mean_tensor * recip_sigma_mv,
        name='offset_compute')

    if freeze_batch_norm_delay is not None:
      use_mv_avg = math_ops.greater_equal(
          common.CreateOrGetQuantizationStep(),
          freeze_batch_norm_delay,
          name='use_moving_average')
    else:
      use_mv_avg = False

    bn_decay_zero = 0.0
    bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
    bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())

    bn_decay_mean_out = utils.smart_cond(
        use_mv_avg,
        lambda: bn_decay_zero,
        lambda: match.bn_decay_mean_tensor,
        name='freeze_moving_mean')
    graph_editor.reroute_ts(
        [bn_decay_mean_out], [match.bn_decay_mean_tensor],
        can_modify=bn_decay_mean_consumers)

    if fused_batch_norm is False:
      bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
      bn_decay_var_out = utils.smart_cond(
          use_mv_avg,
          lambda: bn_decay_zero,
          lambda: match.bn_decay_var_tensor,
          name='freeze_moving_var')
      graph_editor.reroute_ts(
          [bn_decay_var_out], [match.bn_decay_var_tensor],
          can_modify=bn_decay_var_consumers)

    correction_recip = utils.smart_cond(
        use_mv_avg,
        lambda: array_ops.ones(correction_scale.shape),
        lambda: correction_recip,
        name='correction_recip')

    correction_offset = utils.smart_cond(
        use_mv_avg,
        lambda: correction_offset,
        lambda: array_ops.zeros(correction_offset.shape),
        name='correction_offset')
  return correction_scale, correction_recip, correction_offset
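As a quick check of the docstring above: while batch statistics are in use the folding multiplier is effectively gamma/sigma_b, the weights are additionally scaled by correction_scale = sigma_b/sigma_mv so that the quantizer sees the slowly-moving factor gamma/sigma_mv, and the post-conv multiply by correction_recip undoes that scaling. A small numpy verification of those identities (the values are illustrative):

import numpy as np

gamma, eps = 1.3, 1e-3
var_batch, var_moving = 0.25, 0.16               # batch and moving variances
sigma_b = np.sqrt(var_batch + eps)
sigma_mv = np.sqrt(var_moving + eps)

# Same computation as scale_compute above: recip_sigma_mv / recip_sigma.
correction_scale = (1.0 / sigma_mv) / (1.0 / sigma_b)   # == sigma_b / sigma_mv
correction_recip = 1.0 / correction_scale

fold_multiplier = gamma / sigma_b                # folding with batch statistics

# (a) the quantizer sees weights scaled by the slowly-moving gamma/sigma_mv
assert np.isclose(fold_multiplier * correction_scale, gamma / sigma_mv)
# (b) the post-conv multiply restores the batch-statistics scaling gamma/sigma_b
assert np.isclose(fold_multiplier * correction_scale * correction_recip,
                  gamma / sigma_b)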
Beispiel #36
0
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    print("-------------------------------")
    debug_print("Editing model for OME")
    incpu_count = 0
    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    for index, op in enumerate(bwd_ops):
        debug_print("bwd_ops: [{}] :{}".format(index, op.name), 1)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    for index, op in enumerate(fwd_ops):
        debug_print("fwd_ops: [{}] : {}".format(index, op.name), 1)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if not op in xs_ops]
    fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        # remove very small tensors and some weird ops
        def fixdims(
            t
        ):  # tf.Dimension values are not compatible with int, convert manually
            try:
                return [int(e if e.value is not None else 64) for e in t]
            except:
                return [0]  # unknown shape

        ts_all = [
            t for t in ts_all
            if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
        ]
        ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
        ts_all = [t for t in ts_all if 'entropy' not in t.name]
        ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
        ts_all = [t for t in ts_all if 'Switch' not in t.name]
        ts_all = [t for t in ts_all if 'dropout' not in t.name]
        # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
        ts_all = [t for t in ts_all if 'Cast' not in t.name]

        # filter out all tensors that are inputs of the backward graph
        with util.capture_ops() as bwd_ops:
            tf_gradients(ys, xs, grad_ys, **kwargs)

        bwd_inputs = [t for op in bwd_ops for t in op.inputs]
        # list of tensors in forward graph that is in input to bwd graph
        ts_filtered = list(set(bwd_inputs).intersection(ts_all))
        debug_print("Using tensors {}".format(ts_filtered), 1)

        # try two slightly different ways of getting bottlenecks tensors
        # to checkpoint
        for ts in [ts_filtered, ts_all]:

            # get all bottlenecks in the graph
            bottleneck_ts = []
            for t in ts:
                b = set(
                    ge.get_backward_walk_ops(t.op,
                                             inclusive=True,
                                             within_ops=fwd_ops))
                f = set(
                    ge.get_forward_walk_ops(t.op,
                                            inclusive=False,
                                            within_ops=fwd_ops))
                # check that there are not shortcuts
                b_inp = set([inp for op in b
                             for inp in op.inputs]).intersection(ts_all)
                f_inp = set([inp for op in f
                             for inp in op.inputs]).intersection(ts_all)
                if not set(b_inp).intersection(
                        f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                    bottleneck_ts.append(t)  # we have a bottleneck!
                else:
                    debug_print(
                        "Rejected bottleneck candidate and ops {}".format(
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp))),
                        2)

            # success? or try again without filtering?
            if len(bottleneck_ts) >= np.sqrt(
                    len(ts_filtered)):  # yes, enough bottlenecks found!
                break

        if not bottleneck_ts:
            raise Exception(
                'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
            )

        # sort the bottlenecks
        bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                               within_ops=fwd_ops)
        sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts]

        # save an approximately optimal number ~ sqrt(N)
        N = len(ts_filtered)
        if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
            checkpoints = sorted_bottlenecks
        else:
            step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
            checkpoints = sorted_bottlenecks[step::step]

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: {}".format(checkpoints), 1)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print(
            "Warning, some input nodes are also checkpoint nodes: {}".format(
                xs_intersect_checkpoints), 2)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print(
        "ys: {}, checkpoints: {}, intersect: {}".format(
            ys, checkpoints, ys_intersect_checkpoints), 1)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: {}".format(
                format_ops(ys_intersect_checkpoints)), 2)

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    if not checkpoints:
        raise Exception('no checkpoints nodes found or given as input! ')

    debug_print(
        "Select {} nodes to checkpoint nodes.".format(len(checkpoints)), 0)

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        frontier_ops = set(graph_util.get_consuming_ops(x.op.outputs))
        debug_print("my frontier ops: {}".format(frontier_ops), 1)

        bw_frontier_ops = frontier_ops & set(bwd_ops)
        debug_print("my bw frontier ops: {}".format(bw_frontier_ops), 1)

        if len(bw_frontier_ops) > 1:
            continue

        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)

        swapout_op = _add_swapout(grad_node.op, grad_node.op.outputs)
        incpu_count = incpu_count + 1
        swapin_op = _add_swapin(swapout_op, bw_frontier_ops,
                                grad_node.op.outputs)
        checkpoints_disconnected[x] = swapin_op
        my_add_control_inputs(x, bw_frontier_ops, swapin_op)
        # control dependency -> swap_in
        # self._add_control_dependency(src_op, dest_op, swapin_op)

    # g = tf.get_default_graph()
    # print(g.get_operations())

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print(
        "Found {} ops to copy within fwd_ops {}, seed {}, stop_at {}".format(
            len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints), 1)
    debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
    debug_print("Processing list {}".format(ys), 1)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print(
        "Rewired {} in place of {} restricted to {}".format(
            checkpoints_disconnected.values(), checkpoints_disconnected.keys(),
            copied_ops), 1)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients {}".format(dv), 1)
    debug_print("for {}".format(copied_ys), 1)
    debug_print("with respect to {}".format(boundary + xs), 1)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list {}".format(ts), 1)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print(
            "Found {} ops to copy within {}, seed {}, stop_at {}".format(
                len(ops_to_copy), fwd_ops, [r.op for r in ts],
                checkpoints_other), 1)
        debug_print("ops_to_copy = {}".format(ops_to_copy), 1)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied {} to {}".format(ops_to_copy, copied_ops), 1)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print(
            "Rewired {} in place of {} restricted to {}".format(
                checkpoints_disconnected_other, checkpoints_other, copied_ops),
            1)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients {}".format(dv), 1)
        debug_print("for {}".format(boundary), 1)
        debug_print(
            "with respect to {}".format(checkpoints_disconnected_other + xs),
            1)
        debug_print(
            "with boundary backprop substitutions {}".format(
                substitute_backprops), 1)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
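A note on the selection step above: with checkpoints='memory', roughly sqrt(N) of the sorted bottleneck tensors are kept, which is what yields the O(sqrt(N)) activation-memory bound of Chen et al. The sketch below is purely illustrative (pick_checkpoint_count and the example counts are hypothetical, not part of the code above); it only reproduces the counting arithmetic behind the sorted_bottlenecks[step::step] slice.

import numpy as np

def pick_checkpoint_count(num_candidate_tensors, num_bottlenecks):
    # How many sorted bottleneck tensors the heuristic above would keep.
    n = num_candidate_tensors
    if num_bottlenecks <= np.ceil(np.sqrt(n)):
        return num_bottlenecks                       # keep every bottleneck
    step = int(np.ceil(num_bottlenecks / np.sqrt(n)))
    # sorted_bottlenecks[step::step] keeps indices step, 2*step, ...
    return len(range(step, num_bottlenecks, step))

# With 400 candidate tensors and 60 bottlenecks, 19 checkpoints are kept,
# close to the sqrt(400) = 20 target.
print(pick_checkpoint_count(400, 60))  # -> 19
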
def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
  """Finds fused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.
    is_training: Bool, true if training.
    freeze_batch_norm_delay: How many steps to wait before freezing moving mean
      and variance and using them for batch normalization.

  Raises:
    ValueError: When batch norm folding fails.
  """
  for match in _FindFusedBatchNorms(graph):
    scope, sep, _ = match.layer_op.name.rpartition('/')
    # Make sure new ops are added to `graph` and put on the same device as
    # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
    # named `scope`. Otherwise, TF creates a unique scope whose name starts with
    # `scope`.
    with graph.as_default(), graph.name_scope(scope + sep):
      with graph.name_scope(scope + sep + 'BatchNorm_Fold' + sep):
        # new weights = old weights * gamma / sqrt(variance + epsilon)
        # new biases = -mean * gamma / sqrt(variance + epsilon) + beta
        multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
            match.variance_tensor + match.bn_op.get_attr('epsilon'))
        bias_tensor = math_ops.subtract(
            match.beta_tensor,
            match.mean_tensor * multiplier_tensor,
            name='bias')

        correction_scale, correction_recip, correction_offset = None, None, None
        if is_training:
          correction_scale, correction_recip, correction_offset = (
              _ComputeBatchNormCorrections(
                  context='',
                  match=match,
                  freeze_batch_norm_delay=freeze_batch_norm_delay,
                  fused_batch_norm=True))
        # The shape of depthwise weights is different, so we need to reshape the
        # multiplier_tensor to ensure that the scaled_weight_tensor has the
        # expected shape.
        weights = match.weight_tensor
        if match.layer_op.type == 'DepthwiseConv2dNative':
          new_shape = [
              match.weight_tensor.get_shape().as_list()[2],
              match.weight_tensor.get_shape().as_list()[3]
          ]
          multiplier_tensor = array_ops.reshape(
              multiplier_tensor, new_shape, name='scale_reshape')

          if correction_scale is not None:
            correction_scale = array_ops.reshape(
                correction_scale, new_shape, name='correction_reshape')

      if correction_scale is not None:
        weights = math_ops.multiply(
            correction_scale, weights, name='correction_mult')

      scaled_weight_tensor = math_ops.multiply(
          weights, multiplier_tensor, name='mul_fold')
      new_layer_tensor = _CloneWithNewOperands(
          match.layer_op, match.input_tensor, scaled_weight_tensor)

      if correction_recip is not None:
        new_layer_tensor = math_ops.multiply(
            correction_recip, new_layer_tensor, name='post_conv_mul')
        new_layer_tensor = math_ops.add(new_layer_tensor, (correction_offset),
                                        'correction_add')

      bias_add_tensor = math_ops.add(
          new_layer_tensor, bias_tensor, name='add_fold')

      nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
                                                     match.output_tensor)
      if nodes_modified_count == 0:
        raise ValueError('Folding batch norms failed, %s had no outputs.' %
                         match.output_tensor.name)
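The folding in _FoldFusedBatchNorms relies on a simple algebraic identity: once the statistics are frozen, BatchNorm(conv(x, W)) equals conv(x, W * gamma / sqrt(var + eps)) plus the bias beta - mean * gamma / sqrt(var + eps). The NumPy sketch below (standalone, with made-up per-channel values, not taken from the code above) checks that identity for a per-channel affine layer.

import numpy as np

np.random.seed(0)
eps = 1e-3

x = np.random.randn(8, 4)                    # activations with 4 channels
gamma = np.random.randn(4)
beta = np.random.randn(4)
mean = np.random.randn(4)
var = np.random.uniform(0.5, 2.0, size=4)

# Ordinary batch norm applied to the layer output (statistics frozen).
bn_out = gamma * (x - mean) / np.sqrt(var + eps) + beta

# Folded form: scale the weights (here the identity) and add a bias instead.
multiplier = gamma / np.sqrt(var + eps)      # new weights = old weights * multiplier
bias = beta - mean * multiplier              # new bias, as computed by bias_tensor above
folded_out = x * multiplier + bias

assert np.allclose(bn_out, folded_out)
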
Example #38
0
def AddIntegratedGradientsOps(graph,
                              attribution_tensors,
                              output_tensor,
                              num_evals,
                              attribution_dims_map,
                              zero_baseline_tensors=None,
                              new_output_scope='attribution',
                              baseline_scope='baseline',
                              tensors_to_keep=None):
    """Modify graph to create ops for computing integrated gradients.

  Function to modify a tensorflow graph by adding ops for attributing the change
  in value of a given output tensor, to different input 'attribution_tensors'
  (see arxiv.org/abs/1703.01365).

  The first dimension of each attribution_tensor and output_tensor is assumed
  to be the batch dimension. That is, if we create multiple input values for the
  attribution tensors, we should be able to concatenate them along the first
  dimension, and the resulting output tensor should have corresponding values
  for different values of its first dimension.

  The attribution works by interpolating between a given input, and a given
  baseline, to create multiple (num_evals) interpolated inputs. At each
  interpolated input, we compute the gradient of the output tensor with respect
  to each attribution tensor. The gradients for each attribution tensor are
  averaged over all interpolated inputs, to get an attribution score for it.

  Example Usage: attribution_feed_dict = AddIntegratedGradientsOps(...)
  Then to get attribution for a given input (specified by input_feed_dict,
  relative to a baseline given by baseline_feed_dict):
  combined_feed_dict = attribution_feed_dict['create_combined_feed_dict'](
      input_feed_dict, baseline_feed_dict)
  with graph.as_default(), sess.as_default():
    attributions = sess.run(
        attribution_feed_dict['mean_grads'], combined_feed_dict)
  for tensor, attribution in zip(attribution_tensors, attributions):
    print('Attribution for %s: %s' % (tensor.op.name, attribution))

  Warning: This function is not compatible with tf.cond. If there is a tf.cond
  in the graph path between the attribution tensors and the output tensor, the
  attribution ops may not work.
  # TODO(manasrj): Make attribution ops compatible with tf.cond.

  Args:
    graph: The tf.Graph to add attribution ops to.
    attribution_tensors: Tensors for which to compute attribution scores. The
      tensors must satisfy two properties: (1) The output tensor must
      be computable given values for attribution tensors. (2) Each
      attribution tensor must be computationally independent of the
      others, i.e., it should not be the case that one of the
      attribution tensor's value is completely determined by the
      values of the other attribution tensors. Properties (1) and (2) ensure
      the attribution tensors form an input-output cut in the computation
      graph.
    output_tensor: Tensor for whose value we are performing attribution.
    num_evals: Integer scalar. Number of interpolated points at which to
      evaluate gradients. Higher values of this parameter increase computation
      time, but also increase accuracy of attributions.
    attribution_dims_map: Dict mapping attribution tensors to lists of integers.
      For each attribution_tensor, we compute a separate gradient value for each
      slice along the dims in the list. For example, if we have a rank 3
      attribution tensor T that consists of embeddings lookups, with the first
      dimension being the batch dimension, and the second dimension being the
      sparse ids, then setting attribution_dims_map[T] = [1] will give us a
      separate gradient for each sparse id. If an attribution_tensor has no
      entry in attribution_dims_map, then the list defaults to [].
    zero_baseline_tensors: Set of attribution tensors. For each tensor T in this
      set, we compute gradients with respect to T for all interpolated values of
      T between the value computed from the input feed, and zero. For each
      tensor U not in zero_baseline_tensors, we compute gradients for
      interpolated values between the one derived from the input feed, and the
      one derived from the baseline feed.
    new_output_scope: String. New ops needed for computing the output tensor at
      different interpolated values are created under this scope name.
    baseline_scope: String. New ops needed for computing attribution tensor
      interpolated values are created under this scope name.
    tensors_to_keep: Set of tensors. By default, tensors in the graph between
      the output_tensor and attribution tensors are copied to a different part
      of the graph, and evaluated separately for each interpolation. If we want
      a value to be fixed (only computed for the main input instead of each
      interpolation), it should be put in tensors_to_keep.

  Returns:
    attribution_hooks: Dict with the following keys (among others):
      mean_grads: List of attribution scores (aligned with attribution_tensors).
      create_combined_feed_dict: A Function that takes an input feed dict, and
        optionally, a baseline feed dict, and creates a combined feed dict to
        pass to sess.run to get attributions.
  """
    ops_to_tensors = lambda ops: [op.outputs[0] for op in ops]
    attribution_hooks = {}
    if tensors_to_keep is None:
        tensors_to_keep = []
    else:
        tensors_to_keep = list(tensors_to_keep)
    if zero_baseline_tensors is None:
        zero_baseline_tensors = []
    with graph.as_default():
        # Compute parts of graph and check correctness.
        all_ops = graph.get_operations()
        constant_ops = contrib_graph_editor.select.select_ops(
            all_ops, positive_filter=lambda x: x.type == 'Const')
        placeholder_ops = contrib_graph_editor.select.select_ops(
            all_ops, positive_filter=lambda x: x.type == 'Placeholder')
        var_read_ops = contrib_graph_editor.select.select_ops('/read$',
                                                              graph=graph)
        attr_ops = [t.op for t in attribution_tensors]
        required_ops = set(
            contrib_graph_editor.select.get_backward_walk_ops(
                output_tensor.op,
                stop_at_ts=(tensors_to_keep + list(attribution_tensors) +
                            ops_to_tensors(var_read_ops) +
                            ops_to_tensors(placeholder_ops))))

        # Check that attribution tensors are sufficient to compute output_tensor.
        forward_ops = set(
            contrib_graph_editor.select.get_forward_walk_ops(attr_ops +
                                                             var_read_ops +
                                                             constant_ops))
        assert required_ops.issubset(forward_ops)

        required_sgv = contrib_graph_editor.subgraph.make_view(required_ops)
        attribution_subgraph, attribution_transform_info = (
            contrib_graph_editor.transform.copy_with_input_replacements(
                required_sgv, {}, graph, new_output_scope))
        attribution_hooks['attribution_subgraph'] = attribution_subgraph
        attribution_hooks[
            'attribution_transform_info'] = attribution_transform_info

        # Copy feed to attribution part of graph so we can have one part for
        # baseline and one for input.
        backward_ops = contrib_graph_editor.select.get_backward_walk_ops(
            attr_ops, stop_at_ts=ops_to_tensors(var_read_ops))
        backward_sgv = contrib_graph_editor.subgraph.make_view(backward_ops)
        _, baseline_transform_info = (
            contrib_graph_editor.transform.copy_with_input_replacements(
                backward_sgv, {}, graph, baseline_scope))
        attribution_hooks['baseline_transform_info'] = baseline_transform_info

        # Function to compute combined feed dict. The default setting of
        # baseline_transform_info is to get around python's late binding.
        def CreateCombinedFeedDict(
                input_feed_dict,
                baseline_feed_dict=None,
                baseline_transform_info=baseline_transform_info):
            """Combine baseline and input feed dicts into a common feed dict."""
            combined_feed_dict = input_feed_dict.copy()
            if baseline_feed_dict is None:
                baseline_feed_dict = input_feed_dict
            for key, feed_value in baseline_feed_dict.items():
                if isinstance(key, tf.Tensor):
                    combined_feed_dict[baseline_transform_info.transformed(
                        key)] = (feed_value)
                elif isinstance(key, six.text_type):
                    if six.PY2:
                        tensor = graph.get_tensor_by_name(key.decode())
                    else:
                        tensor = graph.get_tensor_by_name(key)
                    combined_feed_dict[baseline_transform_info.transformed(
                        tensor)] = (feed_value)
                elif isinstance(key, tf.SparseTensor):
                    sparse_transformed_tensor = tf.SparseTensor(
                        baseline_transform_info.transformed(key.indices),
                        baseline_transform_info.transformed(key.values),
                        baseline_transform_info.transformed(key.dense_shape))
                    combined_feed_dict[sparse_transformed_tensor] = feed_value
                else:
                    raise ValueError('Invalid key type %s in Feed Dict.' %
                                     type(key))
            return combined_feed_dict

        attribution_hooks['create_combined_feed_dict'] = CreateCombinedFeedDict

        # Create new tensors with the multipliers to insert after previous ones.
        attribution_hooks['multipliers'] = []
        attribution_hooks['weighted_attribution_tensors'] = []
        for attribution_tensor in attribution_tensors:
            with tf.control_dependencies(
                [tf.assert_equal(tf.shape(attribution_tensor)[0], 1)]):
                attribution_dims = (attribution_dims_map[attribution_tensor]
                                    if attribution_tensor
                                    in attribution_dims_map else [])
                vocab_size = len(attribution_tensor.get_shape())
                attribution_dim_cond = tf.sparse_to_indicator(
                    tf.SparseTensor(
                        tf.expand_dims(
                            tf.range(len(attribution_dims), dtype=tf.int64),
                            1), attribution_dims, [vocab_size]), vocab_size)
                base_multiplier_shape = tf.concat([
                    tf.expand_dims(num_evals, 0),
                    tf.ones_like(tf.shape(attribution_tensor))[1:]
                ], 0)
                tile_dims = tf.where(
                    attribution_dim_cond, tf.shape(attribution_tensor),
                    tf.ones_like(tf.shape(attribution_tensor)))
                pre_tile_multiplier = tf.reshape(
                    tf.range(tf.to_float(num_evals)) /
                    tf.to_float(num_evals - 1), base_multiplier_shape)
                multiplier = tf.tile(pre_tile_multiplier, tile_dims)
                if attribution_tensor in zero_baseline_tensors:
                    weighted_attribution_tensor = multiplier * attribution_tensor
                else:
                    base_attribution_tensor = baseline_transform_info.transformed(
                        attribution_tensor)
                    weighted_attribution_tensor = (
                        multiplier * attribution_tensor +
                        (1 - multiplier) * base_attribution_tensor)
                attribution_hooks['weighted_attribution_tensors'].append(
                    weighted_attribution_tensor)
                attribution_hooks['multipliers'].append(multiplier)

        contrib_graph_editor.reroute_ts(
            attribution_hooks['weighted_attribution_tensors'],
            attribution_tensors,
            can_modify=attribution_subgraph.ops)
        g = tf.gradients(attribution_transform_info.transformed(output_tensor),
                         attribution_hooks['multipliers'])
        attribution_hooks['mean_grads'] = [
            tf.reduce_mean(grad, 0) for grad in g
        ]
    return attribution_hooks
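AddIntegratedGradientsOps performs its interpolation by rewiring a copied subgraph, but the attribution it computes is the standard integrated-gradients formula: average the gradient along the straight-line path from the baseline to the input and scale by the input-baseline difference. Below is a self-contained NumPy sketch of that formula for a toy function; the names f, grad_f and integrated_gradients are hypothetical and not part of the API above.

import numpy as np

def f(x):
    # Toy scalar output: sum of squares of the inputs.
    return np.sum(x ** 2)

def grad_f(x):
    return 2.0 * x

def integrated_gradients(x, baseline, num_evals=100):
    # Riemann-sum approximation of the path integral of the gradient.
    alphas = np.linspace(0.0, 1.0, num_evals)
    grads = np.stack([grad_f(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)

x = np.array([1.0, -2.0, 3.0])
baseline = np.zeros_like(x)
attr = integrated_gradients(x, baseline)
# Completeness: the attributions sum (approximately) to f(x) - f(baseline).
print(attr, attr.sum(), f(x) - f(baseline))
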
def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
    '''
    Authors: Tim Salimans & Yaroslav Bulatov

    memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost"
    by Chen et al. 2016 (https://arxiv.org/abs/1604.06174)

    ys, xs, grad_ys, and kwargs are the arguments to standard tensorflow tf.gradients
    (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients)

    'checkpoints' can either be
        - a list consisting of tensors from the forward pass of the neural net
          that we should re-use when calculating the gradients in the backward pass;
          all other tensors that do not appear in this list will be re-computed
        - a string specifying how this list should be determined. currently we support
            - 'speed':  checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive,
                        so checkpointing them maximizes the running speed
                        (this is a good option if nonlinearities, concats, batchnorms, etc. are taking up a lot of memory)
            - 'memory': try to minimize the memory usage
                        (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint)
            - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint
    '''

    #    print("Calling memsaving gradients with", checkpoints)
    if not isinstance(ys, list):
        ys = [ys]
    if not isinstance(xs, list):
        xs = [xs]

    bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], inclusive=True)

    debug_print("bwd_ops: %s", bwd_ops)

    # forward ops are all ops that are candidates for recomputation
    fwd_ops = ge.get_forward_walk_ops([x.op for x in xs],
                                      inclusive=True,
                                      within_ops=bwd_ops)
    debug_print("fwd_ops: %s", fwd_ops)

    # exclude ops with no inputs
    fwd_ops = [op for op in fwd_ops if op.inputs]

    # don't recompute xs, remove variables
    xs_ops = _to_ops(xs)
    fwd_ops = [op for op in fwd_ops if not op in xs_ops]
    fwd_ops = [op for op in fwd_ops if not '/assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name]
    fwd_ops = [op for op in fwd_ops if not '/read' in op.name]
    ts_all = ge.filter_ts(fwd_ops, True)  # get the tensors
    ts_all = [t for t in ts_all if '/read' not in t.name]
    ts_all = set(ts_all) - set(xs) - set(ys)

    # construct list of tensors to checkpoint during forward pass, if not
    # given as input
    if type(checkpoints) is not list:
        if checkpoints == 'collection':
            checkpoints = tf.get_collection('checkpoints')

        elif checkpoints == 'speed':
            # checkpoint all expensive ops to maximize running speed
            checkpoints = ge.filter_ts_from_regex(fwd_ops,
                                                  'conv2d|Conv|MatMul')

        elif checkpoints == 'memory':

            # remove very small tensors and some weird ops
            def fixdims(t):
                # tf.Dimension values are not compatible with int, convert manually
                try:
                    return [int(e if e.value is not None else 64) for e in t]
                except:
                    return [0]  # unknown shape

            ts_all = [
                t for t in ts_all
                if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE
            ]
            ts_all = [t for t in ts_all if 'L2Loss' not in t.name]
            ts_all = [t for t in ts_all if 'entropy' not in t.name]
            ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name]
            ts_all = [t for t in ts_all if 'Switch' not in t.name]
            ts_all = [t for t in ts_all if 'dropout' not in t.name]
            # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
            ts_all = [t for t in ts_all if 'Cast' not in t.name]

            # filter out all tensors that are inputs of the backward graph
            with util.capture_ops() as bwd_ops:
                tf_gradients(ys, xs, grad_ys, **kwargs)

            bwd_inputs = [t for op in bwd_ops for t in op.inputs]
            # list of tensors in forward graph that is in input to bwd graph
            ts_filtered = list(set(bwd_inputs).intersection(ts_all))
            debug_print("Using tensors %s", ts_filtered)

            # try two slightly different ways of getting bottlenecks tensors
            # to checkpoint
            for ts in [ts_filtered, ts_all]:

                # get all bottlenecks in the graph
                bottleneck_ts = []
                for t in ts:
                    b = set(
                        ge.get_backward_walk_ops(t.op,
                                                 inclusive=True,
                                                 within_ops=fwd_ops))
                    f = set(
                        ge.get_forward_walk_ops(t.op,
                                                inclusive=False,
                                                within_ops=fwd_ops))
                    # check that there are not shortcuts
                    b_inp = set([inp for op in b
                                 for inp in op.inputs]).intersection(ts_all)
                    f_inp = set([inp for op in f
                                 for inp in op.inputs]).intersection(ts_all)
                    if not set(b_inp).intersection(
                            f_inp) and len(b_inp) + len(f_inp) >= len(ts_all):
                        bottleneck_ts.append(t)  # we have a bottleneck!
                    else:
                        debug_print(
                            "Rejected bottleneck candidate and ops %s",
                            [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))

                # success? or try again without filtering?
                if len(bottleneck_ts) >= np.sqrt(
                        len(ts_filtered)):  # yes, enough bottlenecks found!
                    break

            if not bottleneck_ts:
                raise Exception(
                    'unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".'
                )

            # sort the bottlenecks
            bottlenecks_sorted_lists = tf_toposort(bottleneck_ts,
                                                   within_ops=fwd_ops)
            sorted_bottlenecks = [
                t for ts in bottlenecks_sorted_lists for t in ts
            ]

            # save an approximately optimal number ~ sqrt(N)
            N = len(ts_filtered)
            if len(bottleneck_ts) <= np.ceil(np.sqrt(N)):
                checkpoints = sorted_bottlenecks
            else:
                step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N)))
                checkpoints = sorted_bottlenecks[step::step]

        else:
            raise Exception('%s is unsupported input for "checkpoints"' %
                            (checkpoints, ))

    checkpoints = list(set(checkpoints).intersection(ts_all))

    # at this point automatic selection happened and checkpoints is list of nodes
    assert isinstance(checkpoints, list)

    debug_print("Checkpoint nodes used: %s", checkpoints)
    # better error handling of special cases
    # xs are already handled as checkpoint nodes, so no need to include them
    xs_intersect_checkpoints = set(xs).intersection(set(checkpoints))
    if xs_intersect_checkpoints:
        debug_print("Warning, some input nodes are also checkpoint nodes: %s",
                    xs_intersect_checkpoints)
    ys_intersect_checkpoints = set(ys).intersection(set(checkpoints))
    debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints,
                ys_intersect_checkpoints)
    # saving an output node (ys) gives no benefit in memory while creating
    # new edge cases, exclude them
    if ys_intersect_checkpoints:
        debug_print(
            "Warning, some output nodes are also checkpoints nodes: %s",
            format_ops(ys_intersect_checkpoints))

    # remove initial and terminal nodes from checkpoints list if present
    checkpoints = list(set(checkpoints) - set(ys) - set(xs))

    # check that we have some nodes to checkpoint
    # if not checkpoints:
    #     raise Exception('no checkpoints nodes found or given as input! ')

    # disconnect dependencies between checkpointed tensors
    checkpoints_disconnected = {}
    for x in checkpoints:
        if x.op and x.op.name is not None:
            grad_node = tf.stop_gradient(x, name=x.op.name + "_sg")
        else:
            grad_node = tf.stop_gradient(x)
        checkpoints_disconnected[x] = grad_node

    # partial derivatives to the checkpointed tensors and xs
    ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys],
                                    stop_at_ts=checkpoints,
                                    within_ops=fwd_ops)
    debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s",
                len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints)
    debug_print("ops_to_copy = %s", ops_to_copy)
    debug_print("Processing list %s", ys)
    copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {})
    for origin_op, op in info._transformed_ops.items():
        op._set_device(origin_op.node_def.device)
    copied_ops = info._transformed_ops.values()
    debug_print("Copied %s to %s", ops_to_copy, copied_ops)
    ge.reroute_ts(checkpoints_disconnected.values(),
                  checkpoints_disconnected.keys(),
                  can_modify=copied_ops)
    debug_print("Rewired %s in place of %s restricted to %s",
                checkpoints_disconnected.values(),
                checkpoints_disconnected.keys(), copied_ops)

    # get gradients with respect to current boundary + original x's
    copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys]
    boundary = list(checkpoints_disconnected.values())
    dv = tf_gradients(ys=copied_ys,
                      xs=boundary + xs,
                      grad_ys=grad_ys,
                      **kwargs)
    debug_print("Got gradients %s", dv)
    debug_print("for %s", copied_ys)
    debug_print("with respect to %s", boundary + xs)

    inputs_to_do_before = [y.op for y in ys]
    if grad_ys is not None:
        inputs_to_do_before += grad_ys
    wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
    my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

    # partial derivatives to the checkpointed nodes
    # dictionary of "node: backprop" for nodes in the boundary
    d_checkpoints = {
        r: dr
        for r, dr in zip(checkpoints_disconnected.keys(),
                         dv[:len(checkpoints_disconnected)])
    }
    # partial derivatives to xs (usually the params of the neural net)
    d_xs = dv[len(checkpoints_disconnected):]

    # incorporate derivatives flowing through the checkpointed nodes
    checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops)
    for ts in checkpoints_sorted_lists[::-1]:
        debug_print("Processing list %s", ts)
        checkpoints_other = [r for r in checkpoints if r not in ts]
        checkpoints_disconnected_other = [
            checkpoints_disconnected[r] for r in checkpoints_other
        ]

        # copy part of the graph below current checkpoint node, stopping at
        # other checkpoints nodes
        ops_to_copy = fast_backward_ops(within_ops=fwd_ops,
                                        seed_ops=[r.op for r in ts],
                                        stop_at_ts=checkpoints_other)
        debug_print("Found %s ops to copy within %s, seed %s, stop_at %s",
                    len(ops_to_copy), fwd_ops, [r.op for r in ts],
                    checkpoints_other)
        debug_print("ops_to_copy = %s", ops_to_copy)
        if not ops_to_copy:  # we're done!
            break
        copied_sgv, info = ge.copy_with_input_replacements(
            ge.sgv(ops_to_copy), {})
        for origin_op, op in info._transformed_ops.items():
            op._set_device(origin_op.node_def.device)
        copied_ops = info._transformed_ops.values()
        debug_print("Copied %s to %s", ops_to_copy, copied_ops)
        ge.reroute_ts(checkpoints_disconnected_other,
                      checkpoints_other,
                      can_modify=copied_ops)
        debug_print("Rewired %s in place of %s restricted to %s",
                    checkpoints_disconnected_other, checkpoints_other,
                    copied_ops)

        # gradient flowing through the checkpointed node
        boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts]
        substitute_backprops = [d_checkpoints[r] for r in ts]
        dv = tf_gradients(boundary,
                          checkpoints_disconnected_other + xs,
                          grad_ys=substitute_backprops,
                          **kwargs)
        debug_print("Got gradients %s", dv)
        debug_print("for %s", boundary)
        debug_print("with respect to %s", checkpoints_disconnected_other + xs)
        debug_print("with boundary backprop substitutions %s",
                    substitute_backprops)

        inputs_to_do_before = [d_checkpoints[r].op for r in ts]
        wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None]
        my_add_control_inputs(wait_to_do_ops, inputs_to_do_before)

        # partial derivatives to the checkpointed nodes
        for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]):
            if dr is not None:
                if d_checkpoints[r] is None:
                    d_checkpoints[r] = dr
                else:
                    d_checkpoints[r] += dr

        def _unsparsify(x):
            if not isinstance(x, tf.IndexedSlices):
                return x
            assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape"
            indices = x.indices
            while indices.shape.ndims < x.values.shape.ndims:
                indices = tf.expand_dims(indices, -1)
            return tf.scatter_nd(indices, x.values, x.dense_shape)

        # partial derivatives to xs (usually the params of the neural net)
        d_xs_new = dv[len(checkpoints_other):]
        for j in range(len(xs)):
            if d_xs_new[j] is not None:
                if d_xs[j] is None:
                    d_xs[j] = _unsparsify(d_xs_new[j])
                else:
                    d_xs[j] += _unsparsify(d_xs_new[j])

    return d_xs
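A typical way to use the gradients function above with checkpoints='collection' is to register the tensors you are willing to keep during model construction and then feed the returned gradients to a standard optimizer. The sketch below assumes TF 1.x graph mode and that the function above is in scope as gradients; the layer sizes and names are illustrative only.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784], name='x')
labels = tf.placeholder(tf.float32, [None, 10], name='labels')

hidden = tf.layers.dense(x, 512, activation=tf.nn.relu)
# Mark the activation we are willing to keep in memory during the forward pass.
tf.add_to_collection('checkpoints', hidden)

logits = tf.layers.dense(hidden, 10)
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits))

params = tf.trainable_variables()
# Drop-in replacement for tf.gradients: everything not in the 'checkpoints'
# collection is recomputed during the backward pass instead of being stored.
grads = gradients(loss, params, checkpoints='collection')
train_op = tf.train.AdamOptimizer(1e-3).apply_gradients(zip(grads, params))
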
Example #40
0
def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
                                 fused_batch_norm):
    """Computes batch norm correction params.

     Before batch normalization is frozen:
     We use batch statistics for batch norm.
       correction_scale = sigma_b/sigma_mv
       correction_recip = 1/correction_scale
       correction_offset = 0

     After batch normalization is frozen:
      correction_scale = sigma_b/sigma_mv
      correction_recip = 1
      correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).

     Batch norm is frozen if global_step > bn_freeze_delay.
     The corrections ensure that:
     a) The weights are quantized after scaling by gamma/sigma_mv. This enables
     smoother training as the scaling on the weights changes slowly, rather than
     jumping across mini-batches.
     b) Changing the values of the corrections allows one to switch between
     using batch statistics and using the moving mean and variance, without
     requiring changes to batch_norm.


  Args:
    context: The scope under which we look for batch norm params
    match: Object containing required batch norm tensors for correction
      computation.
    freeze_batch_norm_delay: Delay in steps at which computation switches
      from regular batch norm to frozen mean and variance.
    fused_batch_norm: Bool, true if fused batch norm is used.

  Returns:
    A tuple of correction_scale, correction_recip, correction_offset
  """

    g = ops.get_default_graph()
    prefix = '' if not context else context + '/'
    with g.name_scope(prefix + 'batch_norm_correction'):
        recip_sigma_mv = math_ops.rsqrt(match.moving_variance_tensor +
                                        match.batch_epsilon)
        recip_sigma = math_ops.rsqrt(match.variance_tensor +
                                     match.batch_epsilon)
        correction_scale = math_ops.divide(recip_sigma_mv,
                                           recip_sigma,
                                           name='scale_compute')
        correction_scale = array_ops.identity(correction_scale,
                                              name='correction_scale')
        correction_recip = math_ops.reciprocal(correction_scale,
                                               name='reciprocal_compute')
        correction_offset = math_ops.multiply(
            match.gamma_tensor,
            match.mean_tensor * recip_sigma -
            match.moving_mean_tensor * recip_sigma_mv,
            name='offset_compute')

        if freeze_batch_norm_delay is not None:
            use_mv_avg = math_ops.greater_equal(
                common.CreateOrGetQuantizationStep(),
                freeze_batch_norm_delay,
                name='use_moving_average')
        else:
            use_mv_avg = False

        bn_decay_zero = 0.0
        bn_decay_mean_consumers = list(match.bn_decay_mean_tensor.consumers())
        bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())

        bn_decay_mean_out = utils.smart_cond(
            use_mv_avg,
            lambda: bn_decay_zero,
            lambda: match.bn_decay_mean_tensor,
            name='freeze_moving_mean')
        graph_editor.reroute_ts([bn_decay_mean_out],
                                [match.bn_decay_mean_tensor],
                                can_modify=bn_decay_mean_consumers)

        if fused_batch_norm is False:
            bn_decay_var_consumers = list(
                match.bn_decay_var_tensor.consumers())
            bn_decay_var_out = utils.smart_cond(
                use_mv_avg,
                lambda: bn_decay_zero,
                lambda: match.bn_decay_var_tensor,
                name='freeze_moving_var')
            graph_editor.reroute_ts([bn_decay_var_out],
                                    [match.bn_decay_var_tensor],
                                    can_modify=bn_decay_var_consumers)

        correction_recip = utils.smart_cond(
            use_mv_avg,
            lambda: array_ops.ones(correction_scale.shape),
            lambda: correction_recip,
            name='correction_recip')

        correction_offset = utils.smart_cond(
            use_mv_avg,
            lambda: correction_offset,
            lambda: array_ops.zeros(correction_offset.shape),
            name='correction_offset')
    return correction_scale, correction_recip, correction_offset
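To see why these corrections work before freezing, note that the folded weights are scaled by gamma / sigma_b and then by correction_scale = sigma_b / sigma_mv, so the quantizer sees weights scaled by gamma / sigma_mv, while multiplying the layer output by correction_recip = 1 / correction_scale restores ordinary batch-norm behaviour with batch statistics. A scalar NumPy sketch (hypothetical values, epsilon omitted for clarity, not taken from the code above):

import numpy as np

np.random.seed(1)

# Per-channel batch-norm parameters and statistics (made-up values).
gamma, beta = 1.5, -0.2
mu_b, sigma_b = 0.3, 1.2      # batch statistics
mu_mv, sigma_mv = 0.1, 0.9    # moving averages

w = 0.7                        # a 1x1 "conv": y = w * x
x = np.random.randn(5)
conv = w * x

multiplier = gamma / sigma_b                  # folding uses batch statistics
correction_scale = sigma_b / sigma_mv
correction_recip = 1.0 / correction_scale     # value used while BN is not frozen
bias = beta - mu_b * multiplier

folded_weight = w * correction_scale * multiplier   # weight seen by the quantizer
out = correction_recip * (folded_weight * x) + bias

# Reference: ordinary batch norm with batch statistics applied to the conv output.
reference = gamma * (conv - mu_b) / sigma_b + beta
assert np.allclose(out, reference)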