def test_train_max_steps_is_not_incremental(self):
    """max_steps is an absolute global-step bound, not a per-call budget.

    The second run resumes from the first run's checkpoint (same output_dir)
    and must stop at global step 15, not 10 + 15.
    """
    for max_steps in (10, 15):
        with ops.Graph().as_default() as g, self.test_session(g):
            with ops.control_dependencies(self._build_inference_graph()):
                train_op = state_ops.assign_add(
                    variables_lib.get_global_step(), 1)
            learn.graph_actions.train(
                g,
                output_dir=self._output_dir,
                train_op=train_op,
                loss_op=constant_op.constant(2.0),
                max_steps=max_steps)
            step = checkpoint_utils.load_variable(
                self._output_dir, variables_lib.get_global_step().name)
            self.assertEqual(max_steps, step)
def test_train_skip_train_if_max_step_already_saved(self):
    """A second run with the same max_steps leaves the saved global step at 10."""
    for _ in range(2):
        with ops.Graph().as_default() as g, self.test_session(g):
            with ops.control_dependencies(self._build_inference_graph()):
                train_op = state_ops.assign_add(
                    variables_lib.get_global_step(), 1)
            learn.graph_actions._monitored_train(  # pylint: disable=protected-access
                g,
                output_dir=self._output_dir,
                train_op=train_op,
                loss_op=constant_op.constant(2.0),
                max_steps=10)
            step = checkpoint_utils.load_variable(
                self._output_dir, variables_lib.get_global_step().name)
            self.assertEqual(10, step)
def test_save_steps_saves_periodically(self):
    """CheckpointSaver with save_steps=2 saves at steps 1, 3 and 5 only."""
    with self.graph.as_default():
        monitor = learn.monitors.CheckpointSaver(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        monitor.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            self._run(monitor, 1, self.train_op, sess)
            # (step just run, global step expected in the latest checkpoint)
            for step, saved_step in ((2, 1), (3, 3), (4, 3), (5, 5)):
                self._run(monitor, step, self.train_op, sess)
                self.assertEqual(
                    saved_step,
                    checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name))
def test_save_steps_saves_periodically(self):
    """CheckpointSaverHook with save_steps=2 saves at steps 1, 3 and 5 only."""
    with self.graph.as_default():
        hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        hook.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            mon_sess = monitored_session._HookedSession(sess, [hook])
            mon_sess.run(self.train_op)
            # Checkpointed global step expected after each of steps 2..5.
            for expected_step in (1, 3, 3, 5):
                mon_sess.run(self.train_op)
                self.assertEqual(
                    expected_step,
                    checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name))
def testGetTensor(self):
    """Each variable written by _create_checkpoints loads back unchanged."""
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
        v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
    expected_by_name = (
        ("var1", v1),
        ("var2", v2),
        ("var3", v3),
        ("useful_scope/var4", v4),
    )
    for var_name, expected in expected_by_name:
        self.assertAllEqual(
            checkpoint_utils.load_variable(checkpoint_dir, var_name), expected)
def testNoTensor(self):
    """Loading a variable that is absent from the checkpoint raises OpError."""
    checkpoint_dir = self.get_temp_dir()
    missing_name = "var5"
    with self.test_session() as session:
        _, _, _, _ = _create_checkpoints(session, checkpoint_dir)
    with self.assertRaises(errors_impl.OpError):
        loaded = checkpoint_utils.load_variable(checkpoint_dir, missing_name)
        self.assertAllEqual(loaded, [])
def print_tensors_in_checkpoint_file(file_name, tensor_name):
    """Print one tensor, or the name/shape listing, of a checkpoint file.

    Args:
      file_name: Name of the checkpoint file.
      tensor_name: Tensor whose value is printed. When falsy (e.g. ``None`` or
        ``""``), every variable name and its shape is printed instead.

    Errors are not propagated: the exception text is printed. A hint about
    SNAPPY compression is also printed when the text suggests it.
    """
    try:
        if tensor_name:
            print("tensor_name: ", tensor_name)
            print(checkpoint_utils.load_variable(file_name, tensor_name))
        else:
            for name, shape in checkpoint_utils.list_variables(file_name):
                print("%s\t%s" % (name, str(shape)))
    except Exception as e:  # pylint: disable=broad-except
        message = str(e)
        print(message)
        if "corrupted compressed block contents" in message:
            print("It's likely that your checkpoint file has been compressed "
                  "with SNAPPY.")
def test_save_secs_saves_periodically(self, mock_time):
    """CheckpointSaverHook with save_secs=2 saves only once 2s have elapsed.

    ``mock_time`` stands in for the clock. Each run sets its return value to a
    fixed offset from a realistic start time.
    """
    start = 1484695987.209386
    with self.graph.as_default():
        mock_time.return_value = start
        hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir, save_secs=2, scaffold=self.scaffold)
        hook.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            mon_sess = monitored_session._HookedSession(sess, [hook])

            def run_at(offset):
                mock_time.return_value = start + offset
                mon_sess.run(self.train_op)

            def saved_step():
                return checkpoint_utils.load_variable(self.model_dir,
                                                      self.global_step.name)

            run_at(0)    # Saved.
            run_at(0.5)  # Not saved.
            self.assertEqual(1, saved_step())
            # Simulate 2.5 seconds of sleep.
            run_at(2.5)  # Saved.
            run_at(2.6)  # Not saved.
            run_at(2.7)  # Not saved.
            self.assertEqual(3, saved_step())
            # Simulate 7.5 more seconds of sleep (10 seconds from start).
            run_at(10)   # Saved.
            self.assertEqual(6, saved_step())
def test_saves_when_saver_and_scaffold_both_missing(self):
    """A hook built with neither saver nor scaffold still checkpoints step 1."""
    with self.graph.as_default():
        hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir, save_steps=1)
        hook.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked = monitored_session._HookedSession(sess, [hook])
            hooked.run(self.train_op)
            saved = checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name)
            self.assertEqual(1, saved)
def test_save_steps_saves_in_first_step(self):
    """CheckpointSaver writes a checkpoint on the very first step."""
    with self.graph.as_default():
        monitor = learn.monitors.CheckpointSaver(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        monitor.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            self._run(monitor, 1, self.train_op, sess)
            saved = checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name)
            self.assertEqual(1, saved)
def test_save_steps_saves_in_first_step(self):
    """CheckpointSaverHook writes a checkpoint on the very first step."""
    with self.graph.as_default():
        hook = basic_session_run_hooks.CheckpointSaverHook(
            self.model_dir, save_steps=2, scaffold=self.scaffold)
        hook.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            hooked = monitored_session._HookedSession(sess, [hook])
            hooked.run(self.train_op)
            saved = checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name)
            self.assertEqual(1, saved)
def test_save_saves_at_end(self):
    """After end(), the checkpoint holds the last step that was run (2)."""
    with self.graph.as_default():
        monitor = learn.monitors.CheckpointSaver(
            self.model_dir, save_secs=2, scaffold=self.scaffold)
        monitor.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)
            for step in (1, 2):
                self._run(monitor, step, self.train_op, sess)
            monitor.end(sess)
            saved = checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name)
            self.assertEqual(2, saved)
def disabled_test_save_secs_saves_periodically(self):
    """CheckpointSaver with save_secs=2, driven by real sleeps (disabled)."""
    with self.graph.as_default():
        monitor = learn.monitors.CheckpointSaver(
            self.model_dir, save_secs=2, scaffold=self.scaffold)
        monitor.begin()
        self.scaffold.finalize()
        with session_lib.Session() as sess:
            sess.run(self.scaffold.init_op)

            def run_steps(*steps):
                for step in steps:
                    self._run(monitor, step, self.train_op, sess)

            def assert_saved(expected_step):
                self.assertEqual(
                    expected_step,
                    checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name))

            run_steps(1, 2)  # Step 2 is not saved.
            assert_saved(1)
            time.sleep(2.5)
            run_steps(3)  # Saved.
            assert_saved(3)
            run_steps(4, 5)  # Not saved.
            assert_saved(3)
            time.sleep(2.5)
            run_steps(6)  # Saved.
            assert_saved(6)
def load_trained_classifier(self, sess, run_id, target_category, iteration):
    """Restore this classifier's trainable variables from a saved checkpoint.

    The checkpoint prefix is
    ``<dir of this file>/../models/<target_category>/<run>/<run>.ckpt``,
    where ``<run>`` is ``lstm<run_id>_iteration<iteration>``. Each trainable
    variable under ``self.classifierName`` is looked up in the checkpoint by
    its name with the leading ``"<classifierName>/"`` removed.

    Args:
      sess: Session in which the variables are assigned.
      run_id: Run identifier used in the checkpoint directory and file name.
      target_category: Sub-directory of ``models`` holding the checkpoint.
      iteration: Training iteration used in the checkpoint directory and
        file name.
    """
    tvars = tf.trainable_variables(scope=self.classifierName)
    run_name = "lstm{0}_iteration{1}".format(run_id, iteration)
    file_path = pathlib.Path(__file__).parent.absolute()
    model_folder = os.path.join(file_path, "..", "models", target_category)
    model_path = os.path.join(model_folder, run_name, run_name + ".ckpt")
    prefix_len = len(self.classifierName) + 1  # drop "<classifierName>/"
    for var in tvars:
        var_name = var.name[prefix_len:]
        source_array = checkpoint_utils.load_variable(
            checkpoint_dir=model_path, name=var_name)
        # Feed the value through a placeholder. tf.assign(var, source_array)
        # would embed the whole array into the graph as a constant node on
        # every call, bloating the GraphDef.
        value = tf.placeholder(var.dtype.base_dtype, shape=var.get_shape())
        sess.run(tf.assign(var, value), feed_dict={value: source_array})
def load_fp32_weights_into_fp16_vars(checkpoint_path: Path) -> List:
    """Load fp32 weights from a checkpoint into (fp16) graph variables.

    For every global variable whose op name also appears in the checkpoint,
    an assign op is added to the ``'restore_ops'`` graph collection. A float32
    checkpoint value is cast to float16 when the target variable is float16.
    Any other value is assigned unchanged, so fp32 variables still receive
    their fp32 weights.

    Assumes that the caller has executed
    ``sess.run(tf.global_variables_initializer())`` and will run the returned
    ops in a session.

    Args:
      checkpoint_path: Checkpoint path.

    Returns:
      The contents of the ``'restore_ops'`` collection: the ops that restore
      the weights in the graph.
    """
    # A set gives O(1) membership checks instead of a list scan per variable.
    checkpoint_variables = {
        var_name for var_name, _ in list_variables(checkpoint_path)
    }
    for graph_var in tf.global_variables():
        if graph_var.op.name not in checkpoint_variables:
            continue
        var = load_variable(checkpoint_path, graph_var.op.name)
        if var.dtype == np.float32 and graph_var.dtype.base_dtype == tf.float16:
            weights = tf.cast(var, tf.float16)
        else:
            # Casting here would make the assign fail for non-fp16 variables.
            weights = var
        tf.add_to_collection('restore_ops', graph_var.assign(weights))
    return tf.get_collection('restore_ops')
def init_graph():
    """Recreate checkpoint variables in the default graph plus amax slots.

    Reads ``var_list`` (``(name, shape)`` pairs) and
    ``tf.flags.FLAGS.init_checkpoint`` from the enclosing scope. Each listed
    variable is declared with its stored dtype. One float32
    ``bert/encoder/layer_<i>/amaxList`` variable is then added per encoder
    layer.

    Returns:
      Tuple ``(layer_num, amaxTotalNum, restore_vars)``:
      - ``layer_num``: highest ``layer_<n>`` index seen, plus one.
      - ``amaxTotalNum``: length of each amaxList.
      - ``restore_vars``: the recreated checkpoint variables.
    """
    restore_vars = []
    max_layer_id = 0
    # Raw string: "\d" inside a plain literal is an invalid escape sequence.
    layer_pattern = re.compile(r'layer_(\d+)')
    amaxTotalNum = 0
    for name, shape in var_list:
        var = checkpoint_utils.load_variable(tf.flags.FLAGS.init_checkpoint,
                                             name)
        if "intermediate/dense/kernel" in name and amaxTotalNum == 0:
            amaxTotalNum = ACTIVATION_AMAX_NUM + 9 * shape[0]
            print(amaxTotalNum, shape[0])
        restore_vars.append(tf.get_variable(name, shape=shape, dtype=var.dtype))
        match = layer_pattern.search(name)
        if match is not None:
            max_layer_id = max(max_layer_id, int(match.group(1)))
    layer_num = max_layer_id + 1
    # Add new var for amax.
    for i in range(layer_num):
        tf.get_variable("bert/encoder/layer_{}/amaxList".format(i),
                        shape=[amaxTotalNum], dtype=tf.float32)
    return layer_num, amaxTotalNum, restore_vars
def train(self):
    """Run this task's role ('ps' or 'worker') in training.

    If ``self.cluster_conf`` exists, the run config and a ``tf.train.Server``
    are built from it. Otherwise a local ``RunConfig`` is used and there is
    no server.
    - A ``'ps'`` task joins its server.
    - A ``'worker'`` task builds the estimator and trains it for
      ``train_steps``. It adds an ``InitEmbeddingsHook`` when
      ``self.init_dir`` is set and the checkpointed global step is below 100.

    Raises:
      ValueError: if ``job_name`` is ``'ps'`` but no cluster config exists,
        so there is no server to join.
    """
    if tf.gfile.Exists(self.cluster_conf):
        cluster_spec, run_config, num_workers = self.gen_run_config()
        server = tf.train.Server(cluster_spec, job_name=self.job_name,
                                 task_index=self.task_index)
    else:
        run_config = tf.contrib.learn.RunConfig(
            save_checkpoints_secs=self.save_checkpoints_secs)
        num_workers = 1
        server = None
    print('-' * 40)
    print('run_config =', run_config)
    if self.job_name == 'ps':
        print("ps start")
        if server is None:
            # Previously an AttributeError on None.join().
            raise ValueError(
                "job_name 'ps' requires a cluster config, but %s does not exist"
                % self.cluster_conf)
        server.join()
    elif self.job_name == 'worker':
        print("worker start")
        train_steps = self.model_parameters.get("train_steps")
        self.model_fn_param['run_config'] = run_config
        builder = WideNDeepModelBuilder(model_desc=self.model_desc_obj)
        model = builder.build_estimator(self.model_fn_param)
        hooks = []
        try:
            global_step = checkpoint_utils.load_variable(
                self.checkpoint_path, tf.GraphKeys.GLOBAL_STEP)
        except Exception:  # pylint: disable=broad-except
            # Presumably no checkpoint yet: treat as a fresh run. A bare
            # `except:` would also swallow KeyboardInterrupt/SystemExit.
            global_step = 0
        print('global_step =', global_step)
        if self.init_dir and global_step < 100:
            print('InitEmbeddings from %s' % self.init_dir)
            hooks.append(InitEmbeddingsHook(self.init_dir))
        model.train(input_fn=self.input_fn_train, steps=train_steps,
                    hooks=hooks)
import tensorflow as tf
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.user_ops import user_ops
from tensorflow.python.training import queue_runner
from tensorflow.contrib.framework.python.framework import checkpoint_utils

# Debug script: dump what is stored in the local "checkpoint" directory.
# The other imports are kept as-is; NOTE(review): some may be needed only
# for import-time side effects — confirm before removing.
# Checkpoint location used previously:
#   hdfs://tensorflow-on-yarn-test/user/yonggang.myg/wnd_plugin_test/kafka_reader_test/checkpoint
CHECKPOINT_DIR = "checkpoint"

# print(x) with a single argument behaves the same on Python 2 and 3. The
# former `print x` statement is a SyntaxError on Python 3. The variable
# listing used to be computed and discarded; it is now printed.
for var_name, var_shape in checkpoint_utils.list_variables(CHECKPOINT_DIR):
    print("%s\t%s" % (var_name, var_shape))
print(checkpoint_utils.load_variable(CHECKPOINT_DIR,
                                     "queue_runner_checkpoint_var"))
def testNoCheckpoints(self):
    """Loading from a directory with no checkpoint raises OpError."""
    empty_dir = "{}/no_checkpoints".format(self.get_temp_dir())
    with self.assertRaises(errors_impl.OpError):
        loaded = checkpoint_utils.load_variable(empty_dir, "var1")
        self.assertAllEqual(loaded, [])
def covariances(self):
    """Return the GMM covariances read from the checkpoint in model_dir."""
    variable_name = gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE
    return checkpoint_utils.load_variable(self.model_dir, variable_name)
def checkpoint_quantization(in_checkpoint_file, out_checkpoint_file,
                            per_channel_quantization):
    """Quantize BERT encoder kernels and store per-layer amax tables.

    Rebuilds every variable listed in the init checkpoint and restores it from
    ``in_checkpoint_file``. For each encoder layer, it then fills a new float32
    ``bert/encoder/layer_<i>/amaxList`` variable with:

    * four entries per activation quantizer in ``amax_name_list``
      (amax, amax/127, amax/127/127, 127/amax), starting at index 0;
    * the kernel amax values returned by
      ``transformer_op_module.weight_quantize``, starting at
      ``ACTIVATION_AMAX_NUM`` (each kernel is also overwritten with the
      quantized kernel that op returns);
    * ``INT8O_GEMM_NUM`` factors ``input_amax * weight_amax /
      (127 * output_amax)``;
    * three TRT fused-MHA amax values.

    Everything is then saved to ``out_checkpoint_file``.

    NOTE(review): the variable listing and all quant_max/quant_min reads use
    ``tf.flags.FLAGS.init_checkpoint``, while the restore uses
    ``in_checkpoint_file``. Presumably both name the same checkpoint — confirm
    with callers.

    Args:
      in_checkpoint_file: Checkpoint restored into the rebuilt graph.
      out_checkpoint_file: Path the quantized checkpoint is saved to.
      per_channel_quantization: Forwarded to ``weight_quantize``.
    """
    init_checkpoint = tf.flags.FLAGS.init_checkpoint
    var_list = checkpoint_utils.list_variables(init_checkpoint)

    def init_graph():
        """Recreate the checkpoint's variables plus one amaxList per layer.

        Returns:
          Tuple ``(layer_num, amaxTotalNum, restore_vars)``.
        """
        restore_vars = []
        max_layer_id = 0
        # Raw string: "\d" inside a plain literal is an invalid escape sequence.
        layer_pattern = re.compile(r'layer_(\d+)')
        amaxTotalNum = 0
        for name, shape in var_list:
            var = checkpoint_utils.load_variable(init_checkpoint, name)
            if "intermediate/dense/kernel" in name and amaxTotalNum == 0:
                # NOTE(review): 9 * shape[0] presumably reserves one
                # per-channel amax per output channel of the six kernels
                # (4*hidden + 4*hidden + hidden, assuming the intermediate
                # size is 4*hidden) — confirm.
                amaxTotalNum = (ACTIVATION_AMAX_NUM + 9 * shape[0] +
                                INT8O_GEMM_NUM + TRT_FUSED_MHA_AMAX_NUM)
                print(amaxTotalNum, shape[0])
            restore_vars.append(
                tf.get_variable(name, shape=shape, dtype=var.dtype))
            match = layer_pattern.search(name)
            if match is not None:
                max_layer_id = max(max_layer_id, int(match.group(1)))
        layer_num = max_layer_id + 1
        # Add new var for amax.
        for i in range(layer_num):
            tf.get_variable("bert/encoder/layer_{}/amaxList".format(i),
                            shape=[amaxTotalNum], dtype=tf.float32)
        return layer_num, amaxTotalNum, restore_vars

    def load_activation_amax(layer, quantizer):
        """Return the larger of |quant_max| and |quant_min| for a quantizer."""
        prefix = "bert/encoder/layer_{}/{}".format(layer, quantizer)
        quant_max = checkpoint_utils.load_variable(init_checkpoint,
                                                   prefix + "/quant_max:0")
        quant_min = checkpoint_utils.load_variable(init_checkpoint,
                                                   prefix + "/quant_min:0")
        if abs(quant_max) > abs(quant_min):
            return abs(quant_max)
        return abs(quant_min)

    def write_scales(amax_list, start, amax):
        """Store amax, amax/127, amax/127/127, 127/amax; return next index."""
        amax_list[start] = amax
        amax_list[start + 1] = amax / 127.0
        amax_list[start + 2] = amax / 127.0 / 127.0
        amax_list[start + 3] = 127.0 / amax
        return start + 4

    layer_num, amaxTotalNum, restore_vars = init_graph()
    restorer = tf.train.Saver(restore_vars)
    saver = tf.train.Saver()

    kernel_name_list = [
        "attention/self/query",
        "attention/self/key",
        "attention/self/value",
        "attention/output/dense",
        "intermediate/dense",
        "output/dense",
    ]
    # Activation quantizers; the comment gives the scale name and position.
    amax_name_list = [
        "attention/self/query/input_quantizer",        # input_scale, 0
        "attention/self/query/aftergemm_quantizer",    # Q_aftergemm_scale, 1
        "attention/self/matmul_q_input_quantizer",     # Qbias_scale, 2
        "attention/self/key/aftergemm_quantizer",      # K_aftergemm_scale, 3
        "attention/self/matmul_k_input_quantizer",     # Kbias_scale, 4
        "attention/self/value/aftergemm_quantizer",    # V_aftergemm_scale, 5
        "attention/self/matmul_v_input_quantizer",     # Vbias_scale, 6
        "attention/self/softmax_input_quantizer",      # bmm1_scale, 7
        "attention/self/matmul_a_input_quantizer",     # Softmax_scale, 8
        "attention/output/dense/input_quantizer",      # bmm2_scale, 9
        "attention/output/dense/aftergemm_quantizer",  # Proj_aftergemm_scale, 10
        "intermediate/dense/input_quantizer",          # ProjBiasNorm_scale, 11
        "intermediate/dense/aftergemm_quantizer",      # FC1_aftergemm_scale, 12
        "output/dense/input_quantizer",                # F1Bias_scale, 13
        "output/dense/aftergemm_quantizer",            # FC2_aftergemm_scale, 14
        "special_F2Bias_scale",                        # F2Bias_scale, 15
    ]
    # Per INT8-output GEMM: weight, input and output quantizer names.
    int8O_gemm_weight_amax_list = [0] * INT8O_GEMM_NUM
    int8O_gemm_weight_list = [
        "attention/self/query",                     # Q_aftergemm
        "attention/self/key",                       # K_aftergemm
        "attention/self/value",                     # V_aftergemm
        "attention/self/matmul_k_input_quantizer",  # bmm1_aftergemm
        "attention/self/matmul_v_input_quantizer",  # bmm2_aftergemm
        "attention/output/dense",                   # Proj_aftergemm
        "intermediate/dense",                       # FC1_aftergemm
        "output/dense",                             # FC2_aftergemm
    ]
    int8O_gemm_input_amax_list = [0] * INT8O_GEMM_NUM
    int8O_gemm_input_list = [
        "attention/self/query/input_quantizer",     # Q_aftergemm
        "attention/self/key/input_quantizer",       # K_aftergemm
        "attention/self/value/input_quantizer",     # V_aftergemm
        "attention/self/matmul_q_input_quantizer",  # bmm1_aftergemm
        "attention/self/matmul_a_input_quantizer",  # bmm2_aftergemm
        "attention/output/dense/input_quantizer",   # Proj_aftergemm
        "intermediate/dense/input_quantizer",       # FC1_aftergemm
        "output/dense/input_quantizer",             # FC2_aftergemm
    ]
    int8O_gemm_output_amax_list = [0] * INT8O_GEMM_NUM
    int8O_gemm_output_list = [
        "attention/self/query/aftergemm_quantizer",    # Q_aftergemm
        "attention/self/key/aftergemm_quantizer",      # K_aftergemm
        "attention/self/value/aftergemm_quantizer",    # V_aftergemm
        "attention/self/softmax_input_quantizer",      # bmm1_aftergemm
        "attention/output/dense/input_quantizer",      # bmm2_aftergemm
        "attention/output/dense/aftergemm_quantizer",  # Proj_aftergemm
        "intermediate/dense/aftergemm_quantizer",      # FC1_aftergemm
        "output/dense/aftergemm_quantizer",            # FC2_aftergemm
    ]

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        restorer.restore(sess, in_checkpoint_file)
        graph = tf.get_default_graph()
        for i in range(layer_num):
            amaxList = np.zeros([amaxTotalNum])

            # --- Activation amax: four entries per quantizer from index 0.
            amax_id = 0
            for amax_name in amax_name_list:
                if amax_name == "special_F2Bias_scale":
                    if i != layer_num - 1:
                        # Uses the next layer's input quantizer.
                        amax = load_activation_amax(i + 1, amax_name_list[0])
                    else:
                        amax = 1.0  # Last layer: not used, placeholder.
                    amax_id = write_scales(amaxList, amax_id, amax)
                    continue
                amax = load_activation_amax(i, amax_name)
                if amax_name in int8O_gemm_input_list:
                    int8O_gemm_input_amax_list[
                        int8O_gemm_input_list.index(amax_name)] = amax
                    if amax_name == "attention/self/query/input_quantizer":
                        # The K and V GEMM inputs reuse the query input amax.
                        for shared in ("attention/self/key/input_quantizer",
                                       "attention/self/value/input_quantizer"):
                            int8O_gemm_input_amax_list[
                                int8O_gemm_input_list.index(shared)] = amax
                if amax_name in int8O_gemm_output_list:
                    int8O_gemm_output_amax_list[
                        int8O_gemm_output_list.index(amax_name)] = amax
                if amax_name in int8O_gemm_weight_list:
                    int8O_gemm_weight_amax_list[
                        int8O_gemm_weight_list.index(amax_name)] = amax
                amax_id = write_scales(amaxList, amax_id, amax)
            print("done process layer_{} activation amax".format(i))

            # --- Kernel amax starts from ACTIVATION_AMAX_NUM.
            amax_id = ACTIVATION_AMAX_NUM
            for kernel_name in kernel_name_list:
                kernel = graph.get_tensor_by_name(
                    "bert/encoder/layer_{}/{}/kernel:0".format(i, kernel_name))
                quantizer = "bert/encoder/layer_{}/{}/kernel_quantizer".format(
                    i, kernel_name)
                quant_max2 = tf.convert_to_tensor(
                    checkpoint_utils.load_variable(
                        init_checkpoint, quantizer + "/quant_max:0"))
                quant_min2 = tf.convert_to_tensor(
                    checkpoint_utils.load_variable(
                        init_checkpoint, quantizer + "/quant_min:0"))
                kernel_processed, quant_max_processed = (
                    transformer_op_module.weight_quantize(
                        kernel, quant_max2, quant_min2,
                        per_channel_quantization=per_channel_quantization))
                kernel_processed_, quant_max_processed_ = sess.run(
                    [kernel_processed, quant_max_processed])
                # Feed the quantized kernel through a placeholder. Assigning
                # the numpy array directly would embed it in the graph as a
                # constant, which also lands in the saved .meta file and can
                # hit the 2GB GraphDef limit on large models.
                new_kernel = tf.placeholder(kernel.dtype.base_dtype,
                                            shape=kernel.get_shape())
                sess.run(tf.assign(kernel, new_kernel),
                         feed_dict={new_kernel: kernel_processed_})
                if kernel_name in int8O_gemm_weight_list:
                    int8O_gemm_weight_amax_list[int8O_gemm_weight_list.index(
                        kernel_name)] = quant_max_processed_[0]
                for value in quant_max_processed_:
                    amaxList[amax_id] = value
                    amax_id += 1

            # --- INT8-output GEMM dequantization factors.
            for j in range(INT8O_GEMM_NUM):
                amaxList[amax_id] = (int8O_gemm_input_amax_list[j] *
                                     int8O_gemm_weight_amax_list[j]) / (
                                         127.0 * int8O_gemm_output_amax_list[j])
                amax_id += 1

            # --- TRT fused MHA amax. Values are read back from the
            # activation entries (index = 4 * position in amax_name_list).
            # QKV_addBias amax: max of Qbias (8), Kbias (16), Vbias (24).
            amaxList[amax_id] = np.maximum(
                np.maximum(amaxList[8], amaxList[16]), amaxList[24])
            amax_id += 1
            # Softmax amax (Softmax_scale, 8 -> 32).
            amaxList[amax_id] = amaxList[32]
            amax_id += 1
            # bmm2 amax (bmm2_scale, 9 -> 36).
            amaxList[amax_id] = amaxList[36]
            amax_id += 1

            amaxL = graph.get_tensor_by_name(
                "bert/encoder/layer_{}/amaxList:0".format(i))
            sess.run(tf.assign(amaxL, amaxList))
            print("done process layer_{} kernel weight".format(i))
        saver.save(sess, out_checkpoint_file)
def load_variable(checkpoint_dir, name):
    """Return the value of variable `name` stored in `checkpoint_dir`.

    Thin pass-through to `checkpoint_utils.load_variable`; see
    `tf.contrib.framework.load_variable`.
    """
    value = checkpoint_utils.load_variable(checkpoint_dir, name)
    return value
def clusters(self):
    """Return the cluster centers, with axis 1 of the stored array squeezed."""
    raw_centers = checkpoint_utils.load_variable(
        self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
    return np.squeeze(raw_centers, axis=1)
def init_graph():
    """Declare graph variables mirroring the init checkpoint, fp32 -> fp16.

    For each ``(name, shape)`` pair in ``var_list``, the stored value is
    loaded from ``tf.flags.FLAGS.init_checkpoint`` only to inspect its dtype.
    float32 entries are declared as float16; any other dtype is kept.
    """
    for var_name, var_shape in var_list:
        stored = checkpoint_utils.load_variable(
            tf.flags.FLAGS.init_checkpoint, var_name)
        if stored.dtype == np.float32:
            target_dtype = tf.float16
        else:
            target_dtype = stored.dtype
        tf.get_variable(var_name, shape=var_shape, dtype=target_dtype)
def testNoCheckpoints(self):
    """Loading from a directory with no checkpoint raises OpError."""
    empty_dir = "{}/no_checkpoints".format(self.get_temp_dir())
    with self.assertRaises(errors_impl.OpError):
        loaded = checkpoint_utils.load_variable(empty_dir, "var1")
        self.assertAllEqual(loaded, [])
def weights(self):
    """Return the GMM cluster weights read from the checkpoint in model_dir."""
    variable_name = gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT
    return checkpoint_utils.load_variable(self.model_dir, variable_name)
def _get_checkpointed_vars(checkpoints, name, fn=lambda x: x): return { k: fn(checkpoint_utils.load_variable(checkpoint_dir=k, name=name)) for k in checkpoints }
def covariances(self):
    """Return the GMM covariances read from the checkpoint in model_dir."""
    variable_name = gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE
    return checkpoint_utils.load_variable(self.model_dir, variable_name)
def clusters(self):
    """Return the cluster centers, with axis 1 of the stored array squeezed."""
    raw_centers = checkpoint_utils.load_variable(
        self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE)
    return np.squeeze(raw_centers, axis=1)
def weights(self):
    """Return the GMM cluster weights read from the checkpoint in model_dir."""
    variable_name = gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT
    return checkpoint_utils.load_variable(self.model_dir, variable_name)
def load_variable(checkpoint_dir, name):
    """Return the value of variable `name` stored in `checkpoint_dir`.

    Thin pass-through to `checkpoint_utils.load_variable`; see
    `tf.contrib.framework.load_variable`.
    """
    value = checkpoint_utils.load_variable(checkpoint_dir, name)
    return value