def test_train_max_steps_is_not_incremental(self):
    with ops.Graph().as_default() as g, self.test_session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          max_steps=10)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(10, step)

    with ops.Graph().as_default() as g, self.test_session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          max_steps=15)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(15, step)
  def test_train_skip_train_if_max_step_already_saved(self):
    with ops.Graph().as_default() as g, self.test_session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions._monitored_train(  # pylint: disable=protected-access
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          max_steps=10)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(10, step)

    with ops.Graph().as_default() as g, self.test_session(g):
      with ops.control_dependencies(self._build_inference_graph()):
        train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
      learn.graph_actions._monitored_train(  # pylint: disable=protected-access
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=constant_op.constant(2.0),
          max_steps=10)
      step = checkpoint_utils.load_variable(
          self._output_dir, variables_lib.get_global_step().name)
      self.assertEqual(10, step)
Example #3
 def test_save_steps_saves_periodically(self):
   with self.graph.as_default():
     monitor = learn.monitors.CheckpointSaver(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     monitor.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       self._run(monitor, 1, self.train_op, sess)
       self._run(monitor, 2, self.train_op, sess)
       # Not saved
       self.assertEqual(1,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
       self._run(monitor, 3, self.train_op, sess)
       # saved
       self.assertEqual(3,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
       self._run(monitor, 4, self.train_op, sess)
       # Not saved
       self.assertEqual(3,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
       self._run(monitor, 5, self.train_op, sess)
       # saved
       self.assertEqual(5,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
 def test_save_steps_saves_periodically(self):
   with self.graph.as_default():
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       mon_sess = monitored_session._HookedSession(sess, [hook])
       mon_sess.run(self.train_op)
       mon_sess.run(self.train_op)
       # Not saved
       self.assertEqual(1,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
       mon_sess.run(self.train_op)
       # saved
       self.assertEqual(3,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
       mon_sess.run(self.train_op)
       # Not saved
       self.assertEqual(3,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
       mon_sess.run(self.train_op)
       # saved
       self.assertEqual(5,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
 def testGetTensor(self):
   checkpoint_dir = self.get_temp_dir()
   with self.test_session() as session:
     v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
   self.assertAllEqual(
       checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
   self.assertAllEqual(
       checkpoint_utils.load_variable(checkpoint_dir, "var2"), v2)
   self.assertAllEqual(
       checkpoint_utils.load_variable(checkpoint_dir, "var3"), v3)
   self.assertAllEqual(
       checkpoint_utils.load_variable(checkpoint_dir, "useful_scope/var4"), v4)
 def testNoTensor(self):
   checkpoint_dir = self.get_temp_dir()
   with self.test_session() as session:
     _, _, _, _ = _create_checkpoints(session, checkpoint_dir)
   with self.assertRaises(errors_impl.OpError):
     self.assertAllEqual(
         checkpoint_utils.load_variable(checkpoint_dir, "var5"), [])
Example #7
def print_tensors_in_checkpoint_file(file_name, tensor_name):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.

  If `tensor_name` is provided, prints the content of the tensor.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
  """
  try:
    if not tensor_name:
      variables = checkpoint_utils.list_variables(file_name)
      for name, shape in variables:
        print("%s\t%s" % (name, str(shape)))
    else:
      print("tensor_name: ", tensor_name)
      print(checkpoint_utils.load_variable(file_name, tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
  def test_save_secs_saves_periodically(self, mock_time):
    # Let's have a realistic start time
    current_time = 1484695987.209386

    with self.graph.as_default():
      mock_time.return_value = current_time
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      hook.begin()
      self.scaffold.finalize()

      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])

        mock_time.return_value = current_time
        mon_sess.run(self.train_op)  # Saved.

        mock_time.return_value = current_time + 0.5
        mon_sess.run(self.train_op)  # Not saved.

        self.assertEqual(1,
                         checkpoint_utils.load_variable(self.model_dir,
                                                        self.global_step.name))

        # Simulate 2.5 seconds of sleep.
        mock_time.return_value = current_time + 2.5
        mon_sess.run(self.train_op)  # Saved.

        mock_time.return_value = current_time + 2.6
        mon_sess.run(self.train_op)  # Not saved.

        mock_time.return_value = current_time + 2.7
        mon_sess.run(self.train_op)  # Not saved.

        self.assertEqual(3,
                         checkpoint_utils.load_variable(self.model_dir,
                                                        self.global_step.name))

        # Simulate 7.5 more seconds of sleep (10 seconds from start).
        mock_time.return_value = current_time + 10
        mon_sess.run(self.train_op)  # Saved.
        self.assertEqual(6,
                         checkpoint_utils.load_variable(self.model_dir,
                                                        self.global_step.name))
 def test_saves_when_saver_and_scaffold_both_missing(self):
   with self.graph.as_default():
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=1)
     hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       mon_sess = monitored_session._HookedSession(sess, [hook])
       mon_sess.run(self.train_op)
       self.assertEqual(1,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
 def test_save_steps_saves_in_first_step(self):
     with self.graph.as_default():
         monitor = learn.monitors.CheckpointSaver(self.model_dir,
                                                  save_steps=2,
                                                  scaffold=self.scaffold)
         monitor.begin()
         self.scaffold.finalize()
         with session_lib.Session() as sess:
             sess.run(self.scaffold.init_op)
             self._run(monitor, 1, self.train_op, sess)
             self.assertEqual(
                 1,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
 def test_save_steps_saves_in_first_step(self):
     with self.graph.as_default():
         hook = basic_session_run_hooks.CheckpointSaverHook(
             self.model_dir, save_steps=2, scaffold=self.scaffold)
         hook.begin()
         self.scaffold.finalize()
         with session_lib.Session() as sess:
             sess.run(self.scaffold.init_op)
             mon_sess = monitored_session._HookedSession(sess, [hook])
             mon_sess.run(self.train_op)
             self.assertEqual(
                 1,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
Example #12
 def test_save_saves_at_end(self):
   with self.graph.as_default():
     monitor = learn.monitors.CheckpointSaver(
         self.model_dir, save_secs=2, scaffold=self.scaffold)
     monitor.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       self._run(monitor, 1, self.train_op, sess)
       self._run(monitor, 2, self.train_op, sess)
       monitor.end(sess)
       self.assertEqual(2,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
 def disabled_test_save_secs_saves_periodically(self):
     with self.graph.as_default():
         monitor = learn.monitors.CheckpointSaver(self.model_dir,
                                                  save_secs=2,
                                                  scaffold=self.scaffold)
         monitor.begin()
         self.scaffold.finalize()
         with session_lib.Session() as sess:
             sess.run(self.scaffold.init_op)
             self._run(monitor, 1, self.train_op, sess)
             self._run(monitor, 2, self.train_op, sess)
             # Not saved
             self.assertEqual(
                 1,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
             time.sleep(2.5)
             self._run(monitor, 3, self.train_op, sess)
             # saved
             self.assertEqual(
                 3,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
             self._run(monitor, 4, self.train_op, sess)
             self._run(monitor, 5, self.train_op, sess)
             # Not saved
             self.assertEqual(
                 3,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
             time.sleep(2.5)
             self._run(monitor, 6, self.train_op, sess)
             # saved
             self.assertEqual(
                 6,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
 def load_trained_classifier(self, sess, run_id, target_category,
                             iteration):
     tvars = tf.trainable_variables(scope=self.classifierName)
     file_path = pathlib.Path(__file__).parent.absolute()
     model_folder = os.path.join(file_path, "..", "models", target_category)
     checkpoint_folder = os.path.join(
         model_folder, "lstm{0}_iteration{1}".format(run_id, iteration))
     model_path = os.path.join(
         checkpoint_folder,
         "lstm{0}_iteration{1}.ckpt".format(run_id, iteration))
     saved_vars = checkpoint_utils.list_variables(checkpoint_dir=model_path)
     for var in tvars:
         # assert len([_var for _var in saved_vars if _var.name == var.name]) == 1
         # if "Adam" in var.name:
         #     continue
         var_name = var.name[len(self.classifierName) + 1:]
         source_array = checkpoint_utils.load_variable(
             checkpoint_dir=model_path, name=var_name)
         tf.assign(var, source_array).eval(session=sess)
Example #15
def load_fp32_weights_into_fp16_vars(checkpoint_path: Path) -> List:
    """Load fp32 weights from checkpoint path into fp16 variables.

    Assumes that the caller has already executed `sess.run(tf.global_variables_initializer())`

    Args:
        checkpoint_path: Checkpoint path

    Returns:
        Collection of ops to use to restore the weights in the graph.
    """
    checkpoint_variables = [var_name for var_name, _ in list_variables(checkpoint_path)]

    for graph_var in tf.global_variables():
        if graph_var.op.name in checkpoint_variables:
            var = load_variable(checkpoint_path, graph_var.op.name)
            weights = tf.cast(var, tf.float16) if var.dtype == np.float32 else var
            tf.add_to_collection('restore_ops', graph_var.assign(weights))
    return tf.get_collection('restore_ops')
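
A hedged usage sketch for the loader above (the session handling and the checkpoint path are assumptions, not part of the source):

import tensorflow as tf
from pathlib import Path

with tf.Session() as sess:
    # Per the docstring, variables must be initialized before restoring.
    sess.run(tf.global_variables_initializer())
    restore_ops = load_fp32_weights_into_fp16_vars(Path("/tmp/fp32_model.ckpt"))
    sess.run(restore_ops)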
Example #16
 def init_graph():
     restore_vars = []
     layer_num = 0
     regex = re.compile(r'layer_\d+')
     amaxTotalNum = 0
     for name, shape in var_list:
         var = checkpoint_utils.load_variable(tf.flags.FLAGS.init_checkpoint, name)
         if "intermediate/dense/kernel" in name and amaxTotalNum == 0:
             amaxTotalNum = ACTIVATION_AMAX_NUM + 9*shape[0]
             print(amaxTotalNum, shape[0])
         recon_dtype = var.dtype
         restore_vars.append(tf.get_variable(name, shape=shape, dtype=var.dtype))
         tmp = regex.findall(name)
         if len(tmp) < 1:
             continue
         num_tmp = int(tmp[0].replace("layer_", ""))
         if layer_num < num_tmp:
             layer_num = num_tmp
     layer_num = layer_num + 1
     #add new var for amax
     for i in range(layer_num):
         tf.get_variable("bert/encoder/layer_{}/amaxList".format(i), shape=[amaxTotalNum], dtype=tf.float32)
     return layer_num, amaxTotalNum, restore_vars
Example #17
    def train(self):
        if tf.gfile.Exists(self.cluster_conf):
            cluster_spec, run_config, num_workers = self.gen_run_config()
            server = tf.train.Server(cluster_spec,
                                     job_name=self.job_name,
                                     task_index=self.task_index)
        else:
            run_config = tf.contrib.learn.RunConfig(
                save_checkpoints_secs=self.save_checkpoints_secs)
            num_workers = 1
            server = None

        print ('-' * 40)
        print ('run_config =', run_config)

        if self.job_name == 'ps':
            print ("ps start")
            server.join()
        elif self.job_name == 'worker':
            print ("worker start")
            train_steps = self.model_parameters.get("train_steps")
            self.model_fn_param['run_config'] = run_config
            builder = WideNDeepModelBuilder(model_desc=self.model_desc_obj)
            model = builder.build_estimator(self.model_fn_param)
            hooks = []
            try:
                global_step = checkpoint_utils.load_variable(
                        self.checkpoint_path, tf.GraphKeys.GLOBAL_STEP)
            except Exception:  # no checkpoint yet; start from step 0
                global_step = 0
            print ('global_step =', global_step)
            if self.init_dir and global_step < 100:
                print ('InitEmbeddings from %s' % self.init_dir)
                hooks.append(InitEmbeddingsHook(self.init_dir))
            model.train(input_fn=self.input_fn_train,
                        steps=train_steps,
                        hooks=hooks)
Example #18
import tensorflow as tf
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.user_ops import user_ops
from tensorflow.python.training import queue_runner
from tensorflow.contrib.framework.python.framework import checkpoint_utils

# print(checkpoint_utils.list_variables("hdfs://tensorflow-on-yarn-test/user/yonggang.myg/wnd_plugin_test/kafka_reader_test/checkpoint"))
print(checkpoint_utils.list_variables("checkpoint"))
print(checkpoint_utils.load_variable("checkpoint",
                                     "queue_runner_checkpoint_var"))

# tvars = tf.trainable_variables()
# tvars_vals = sess.run(tvars)
#
# for var, val in zip(tvars, tvars_vals):
#     print(var.name, val)
 def testNoCheckpoints(self):
   checkpoint_dir = self.get_temp_dir() + "/no_checkpoints"
   with self.assertRaises(errors_impl.OpError):
     self.assertAllEqual(
         checkpoint_utils.load_variable(checkpoint_dir, "var1"), [])
Example #20
 def covariances(self):
   """Returns the covariances."""
   return checkpoint_utils.load_variable(
       self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
def checkpoint_quantization(in_checkpoint_file, out_checkpoint_file,
                            per_channel_quantization):
    var_list = checkpoint_utils.list_variables(tf.flags.FLAGS.init_checkpoint)

    def init_graph():
        restore_vars = []
        layer_num = 0
        regex = re.compile(r'layer_\d+')
        amaxTotalNum = 0
        for name, shape in var_list:
            var = checkpoint_utils.load_variable(
                tf.flags.FLAGS.init_checkpoint, name)
            if "intermediate/dense/kernel" in name and amaxTotalNum == 0:
                amaxTotalNum = ACTIVATION_AMAX_NUM + 9 * shape[
                    0] + INT8O_GEMM_NUM + TRT_FUSED_MHA_AMAX_NUM
                print(amaxTotalNum, shape[0])
            recon_dtype = var.dtype
            restore_vars.append(
                tf.get_variable(name, shape=shape, dtype=var.dtype))
            tmp = regex.findall(name)
            if len(tmp) < 1:
                continue
            num_tmp = int(tmp[0].replace("layer_", ""))
            if layer_num < num_tmp:
                layer_num = num_tmp
        layer_num = layer_num + 1
        #add new var for amax
        for i in range(layer_num):
            tf.get_variable("bert/encoder/layer_{}/amaxList".format(i),
                            shape=[amaxTotalNum],
                            dtype=tf.float32)
        return layer_num, amaxTotalNum, restore_vars

    layer_num, amaxTotalNum, restore_vars = init_graph()
    restorer = tf.train.Saver(restore_vars)
    saver = tf.train.Saver()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        restorer.restore(sess, in_checkpoint_file)
        kernel_name_list = [
            "attention/self/query", "attention/self/key",
            "attention/self/value", "attention/output/dense",
            "intermediate/dense", "output/dense"
        ]

        #input_scale, 0
        amax_name_list = [
            "attention/self/query/input_quantizer",
            #Q_aftergemm_scale, 1
            "attention/self/query/aftergemm_quantizer",
            #Qbias_scale, 2
            "attention/self/matmul_q_input_quantizer",
            #K_aftergemm_scale, 3
            "attention/self/key/aftergemm_quantizer",
            #Kbias_scale, 4
            "attention/self/matmul_k_input_quantizer",
            #V_aftergemm_scale, 5
            "attention/self/value/aftergemm_quantizer",
            #Vbias_scale, 6
            "attention/self/matmul_v_input_quantizer",
            #bmm1_scale, 7
            "attention/self/softmax_input_quantizer",
            #Softmax_scale, 8
            "attention/self/matmul_a_input_quantizer",
            #bmm2_scale, 9
            "attention/output/dense/input_quantizer",
            #Proj_aftergemm_scale, 10
            "attention/output/dense/aftergemm_quantizer",
            #ProjBiasNorm_scale, 11
            "intermediate/dense/input_quantizer",
            #FC1_aftergemm_scale, 12
            "intermediate/dense/aftergemm_quantizer",
            #F1Bias_scale, 13
            "output/dense/input_quantizer",
            #FC2_aftergemm_scale, 14
            "output/dense/aftergemm_quantizer",
            #F2Bias_scale, 15
            "special_F2Bias_scale",
        ]

        int8O_gemm_weight_amax_list = [0 for i in range(INT8O_GEMM_NUM)]
        #Q_aftergemm
        int8O_gemm_weight_list = [
            "attention/self/query",
            #K_aftergemm
            "attention/self/key",
            #V_aftergemm
            "attention/self/value",
            #bmm1_aftergemm
            "attention/self/matmul_k_input_quantizer",
            #bmm2_aftergemm
            "attention/self/matmul_v_input_quantizer",
            #Proj_aftergemm
            "attention/output/dense",
            #FC1_aftergemm
            "intermediate/dense",
            #FC2_aftergemm
            "output/dense"
        ]

        int8O_gemm_input_amax_list = [0 for i in range(INT8O_GEMM_NUM)]
        #Q_aftergemm
        int8O_gemm_input_list = [
            "attention/self/query/input_quantizer",
            #K_aftergemm
            "attention/self/key/input_quantizer",
            #V_aftergemm
            "attention/self/value/input_quantizer",
            #bmm1_aftergemm
            "attention/self/matmul_q_input_quantizer",
            #bmm2_aftergemm
            "attention/self/matmul_a_input_quantizer",
            #Proj_aftergemm
            "attention/output/dense/input_quantizer",
            #FC1_aftergemm
            "intermediate/dense/input_quantizer",
            #FC2_aftergemm
            "output/dense/input_quantizer"
        ]

        int8O_gemm_output_amax_list = [0 for i in range(INT8O_GEMM_NUM)]
        #Q_aftergemm
        int8O_gemm_output_list = [
            "attention/self/query/aftergemm_quantizer",
            #K_aftergemm
            "attention/self/key/aftergemm_quantizer",
            #V_aftergemm
            "attention/self/value/aftergemm_quantizer",
            #bmm1_aftergemm
            "attention/self/softmax_input_quantizer",
            #bmm2_aftergemm
            "attention/output/dense/input_quantizer",
            #Proj_aftergemm
            "attention/output/dense/aftergemm_quantizer",
            #FC1_aftergemm
            "intermediate/dense/aftergemm_quantizer",
            #FC2_aftergemm
            "output/dense/aftergemm_quantizer"
        ]

        factor = 1000000.0
        for i in range(layer_num):
            amaxList = np.zeros([amaxTotalNum])
            amax_id = 0
            for amax_name in amax_name_list:
                if amax_name == "special_F2Bias_scale":
                    if i != layer_num - 1:
                        name = "bert/encoder/layer_{}/{}/quant_max:0".format(
                            i + 1, amax_name_list[0])
                        quant_max = checkpoint_utils.load_variable(
                            tf.flags.FLAGS.init_checkpoint, name)
                        name = "bert/encoder/layer_{}/{}/quant_min:0".format(
                            i + 1, amax_name_list[0])
                        quant_min = checkpoint_utils.load_variable(
                            tf.flags.FLAGS.init_checkpoint, name)
                        if abs(quant_max) > abs(quant_min):
                            amax = abs(
                                quant_max)  #int(abs(quant_max)*factor)/factor
                        else:
                            amax = abs(
                                quant_min)  #int(abs(quant_min)*factor)/factor
                    else:

                        #not used, placeholder
                        amax = 1.0

                    amaxList[amax_id] = amax
                    amax_id += 1
                    amaxList[amax_id] = amax / 127.0
                    amax_id += 1
                    amaxList[amax_id] = amax / 127.0 / 127.0
                    amax_id += 1
                    amaxList[amax_id] = 127.0 / amax
                    amax_id += 1
                    continue

                name = "bert/encoder/layer_{}/{}/quant_max:0".format(
                    i, amax_name)
                quant_max = checkpoint_utils.load_variable(
                    tf.flags.FLAGS.init_checkpoint, name)
                name = "bert/encoder/layer_{}/{}/quant_min:0".format(
                    i, amax_name)
                quant_min = checkpoint_utils.load_variable(
                    tf.flags.FLAGS.init_checkpoint, name)

                if abs(quant_max) > abs(quant_min):
                    amax = abs(quant_max)  #int(abs(quant_max)*factor)/factor
                else:
                    amax = abs(quant_min)  #int(abs(quant_min)*factor)/factor

                if amax_name in int8O_gemm_input_list:
                    int8O_gemm_input_amax_list[int8O_gemm_input_list.index(
                        amax_name)] = amax
                    if amax_name == "attention/self/query/input_quantizer":
                        int8O_gemm_input_amax_list[int8O_gemm_input_list.index(
                            "attention/self/key/input_quantizer")] = amax
                        int8O_gemm_input_amax_list[int8O_gemm_input_list.index(
                            "attention/self/value/input_quantizer")] = amax

                if amax_name in int8O_gemm_output_list:
                    int8O_gemm_output_amax_list[int8O_gemm_output_list.index(
                        amax_name)] = amax

                if amax_name in int8O_gemm_weight_list:
                    int8O_gemm_weight_amax_list[int8O_gemm_weight_list.index(
                        amax_name)] = amax

                amaxList[amax_id] = amax
                amax_id += 1
                amaxList[amax_id] = amax / 127.0
                amax_id += 1
                amaxList[amax_id] = amax / 127.0 / 127.0
                amax_id += 1
                amaxList[amax_id] = 127.0 / amax
                amax_id += 1

            print("done process layer_{} activation amax".format(i))

            #kernel amax starts from ACTIVATION_AMAX_NUM
            amax_id = ACTIVATION_AMAX_NUM
            for kernel_id, kernel_name in enumerate(kernel_name_list):
                kernel = tf.get_default_graph().get_tensor_by_name(
                    "bert/encoder/layer_{}/{}/kernel:0".format(i, kernel_name))

                name = "bert/encoder/layer_{}/{}/kernel_quantizer/quant_max:0".format(
                    i, kernel_name)
                quant_max2 = tf.convert_to_tensor(
                    checkpoint_utils.load_variable(
                        tf.flags.FLAGS.init_checkpoint, name))

                name = "bert/encoder/layer_{}/{}/kernel_quantizer/quant_min:0".format(
                    i, kernel_name)
                quant_min2 = tf.convert_to_tensor(
                    checkpoint_utils.load_variable(
                        tf.flags.FLAGS.init_checkpoint, name))

                kernel_processed, quant_max_processed = transformer_op_module.weight_quantize(
                    kernel,
                    quant_max2,
                    quant_min2,
                    per_channel_quantization=per_channel_quantization)
                kernel_processed_, quant_max_processed_ = sess.run(
                    [kernel_processed, quant_max_processed])
                sess.run(tf.assign(kernel, kernel_processed_))
                if kernel_name in int8O_gemm_weight_list:
                    int8O_gemm_weight_amax_list[int8O_gemm_weight_list.index(
                        kernel_name)] = quant_max_processed_[0]
                for e in quant_max_processed_:
                    amaxList[amax_id] = e
                    amax_id += 1

            #for int8O gemm deQuant
            for j in range(INT8O_GEMM_NUM):
                amaxList[amax_id] = (int8O_gemm_input_amax_list[j] *
                                     int8O_gemm_weight_amax_list[j]) / (
                                         127.0 *
                                         int8O_gemm_output_amax_list[j])
                amax_id += 1

            #for trt fused MHA amax
            #### QKV_addBias_amax
            amaxList[amax_id] = np.maximum(
                np.maximum(amaxList[8], amaxList[16]), amaxList[24])
            amax_id += 1
            #### softmax amax
            amaxList[amax_id] = amaxList[32]
            amax_id += 1
            #### bmm2 amax
            amaxList[amax_id] = amaxList[36]
            amax_id += 1
            amaxL = tf.get_default_graph().get_tensor_by_name(
                "bert/encoder/layer_{}/amaxList:0".format(i))
            sess.run(tf.assign(amaxL, amaxList))

            print("done process layer_{} kernel weight".format(i))

        saver.save(sess, out_checkpoint_file)
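
A hedged invocation sketch for the function above (the file paths are placeholders; tf.flags.FLAGS.init_checkpoint is assumed to already point at the quantization-aware-training checkpoint):

checkpoint_quantization(
    in_checkpoint_file="/tmp/bert_qat/model.ckpt",
    out_checkpoint_file="/tmp/bert_int8/model.ckpt",
    per_channel_quantization=True)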
Example #22
def load_variable(checkpoint_dir, name):
    """See `tf.contrib.framework.load_variable`."""
    return checkpoint_utils.load_variable(checkpoint_dir, name)
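
A one-line usage sketch for the wrapper above (the directory and variable name are placeholders):

global_step = load_variable("/tmp/model_dir", "global_step")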
Example #23
 def clusters(self):
     """Returns cluster centers."""
     clusters = checkpoint_utils.load_variable(
         self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
     return np.squeeze(clusters, 1)
Example #24
 def init_graph():
     for name, shape in var_list:
         var = checkpoint_utils.load_variable(
             tf.flags.FLAGS.init_checkpoint, name)
         recon_dtype = tf.float16 if var.dtype == np.float32 else var.dtype
         tf.get_variable(name, shape=shape, dtype=recon_dtype)
 def testNoCheckpoints(self):
     checkpoint_dir = self.get_temp_dir() + "/no_checkpoints"
     with self.assertRaises(errors_impl.OpError):
         self.assertAllEqual(
             checkpoint_utils.load_variable(checkpoint_dir, "var1"), [])
Example #26
 def weights(self):
   """Returns the cluster weights."""
   return checkpoint_utils.load_variable(
       self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT)
Example #27
def _get_checkpointed_vars(checkpoints, name, fn=lambda x: x):
    return {
        k: fn(checkpoint_utils.load_variable(checkpoint_dir=k, name=name))
        for k in checkpoints
    }
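
A small usage sketch for the helper above (the checkpoint directories and variable name are hypothetical):

steps = _get_checkpointed_vars(
    checkpoints=["/tmp/run_a", "/tmp/run_b"],
    name="global_step",
    fn=int)
# -> {"/tmp/run_a": <step in run_a>, "/tmp/run_b": <step in run_b>}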
Example #28
 def covariances(self):
     """Returns the covariances."""
     return checkpoint_utils.load_variable(
         self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE)
Example #29
 def clusters(self):
   """Returns cluster centers."""
   clusters = checkpoint_utils.load_variable(
       self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
   return np.squeeze(clusters, 1)
Example #30
 def weights(self):
     """Returns the cluster weights."""
     return checkpoint_utils.load_variable(
         self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT)
Example #31
def load_variable(checkpoint_dir, name):
  """See `tf.contrib.framework.load_variable`."""
  return checkpoint_utils.load_variable(checkpoint_dir, name)