def get_init_values_for_pruned_layers(prune_scopes, shorten_scopes,
                                      kept_percentage, prune_info):
    """Compute initial values for layers that are about to be pruned.

    Layers are pruned iteratively, so ``prune_scopes`` and ``shorten_scopes``
    are expected to hold a single scope each.

    A temporary graph is built containing the network described by
    ``prune_info``. It is restored from ``FLAGS.checkpoint_path``, and
    ``get_pruned_kernel_matrix`` derives the pruned values from it.

    Args:
        prune_scopes: scopes to prune, forwarded to ``get_pruned_kernel_matrix``.
        shorten_scopes: scopes to shorten, forwarded to
            ``get_pruned_kernel_matrix``.
        kept_percentage: kept percentage(s), forwarded to
            ``get_pruned_kernel_matrix``.
        prune_info: pruning description used to build the network that
            matches the checkpoint.

    Returns:
        The value returned by ``get_pruned_kernel_matrix``. Callers look it up
        by variable op name (``key in ...`` / ``.get(key)``).
    """
    # The first stage restores from the original checkpoint scope; later
    # stages build the network under the pruned scope.
    if FLAGS.stage_index == 0:
        net_scope = FLAGS.net_name_scope_checkpoint
    else:
        net_scope = FLAGS.net_name_scope_pruned

    scratch_graph = tf.Graph()
    with scratch_graph.as_default():
        deployment = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)
        train_set = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                                FLAGS.dataset_dir)
        input_queue = train_inputs(train_set, deployment, FLAGS)
        input_images, _unused_labels = input_queue.dequeue()

        build_net = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(train_set.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay,
            prune_info=prune_info)
        # Only the variables are needed; the network output is discarded.
        build_net(input_images, is_training=False, scope=net_scope)

        with tf.Session() as session:
            load_checkpoint(session, FLAGS.checkpoint_path)
            print("get pruned kernel matrix with kept_percentage",
                  kept_percentage)
            init_values = get_pruned_kernel_matrix(session, prune_scopes,
                                                   shorten_scopes,
                                                   kept_percentage)

    # Drop the local reference to the temporary graph.
    del scratch_graph
    return init_values
def main(_):
    """Prune one more step of the network and fine-tune it.

    Steps:
    * Choose a kept percentage for this stage (``choose_kp``).
    * Derive the current configuration from the previous one
      (``generate_config``).
    * On a fresh run, initialize the newly pruned/shortened layers from
      ``get_init_values_for_pruned_layers``. Restore every other model
      variable from ``FLAGS.checkpoint_path``.
    * Train on cross-entropy (plus regularization). Training stops at
      ``FLAGS.max_number_of_steps``, or as soon as an evaluation reaches
      ``FLAGS.baseline_accuracy - FLAGS.drop_rate``.

    When ``FLAGS.continue_training`` is set and ``train_dir`` already holds a
    checkpoint, all variables are restored from that checkpoint instead.
    """
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # init
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    if FLAGS.stage_index:
        # After the first stage the checkpoint scope is taken to be the pruned
        # scope (presumably the previous stage saved under it -- confirm).
        net_name_scope_checkpoint = net_name_scope_pruned

    # choose the pruning option for the stage
    kp = choose_kp()
    pre_config = load_checkpoint_config()
    print('Load previous config:', len(pre_config), pre_config)
    fake_config, cur_config, new_prune_scopes, new_kept_percentage = generate_config(
        pre_config, kp)
    print('Generate fake_config:', fake_config)
    print('Generate config:', len(cur_config), cur_config)
    print('new prune scopes:', len(new_prune_scopes), new_prune_scopes)

    # Expand the per-unit configurations into per-scope lists.
    # - Units kept at 1.0 are skipped (not pruned).
    # - Every other unit adds its indexed prune scopes and its kept percentage
    #   twice (presumably one entry per scope of the unit -- TODO confirm).
    pre_prune_scopes, pre_kept_percentage = [], []
    for i, unit_kp in enumerate(pre_config):
        if unit_kp == 1.0:
            continue
        pre_kept_percentage.extend([unit_kp, unit_kp])
        pre_prune_scopes.extend(valid_indexed_prune_scopes_for_units[i])

    cur_prune_scopes, cur_kept_percentage = [], []
    for i, unit_kp in enumerate(cur_config):
        if unit_kp == 1.0:
            continue
        cur_kept_percentage.extend([unit_kp, unit_kp])
        cur_prune_scopes.extend(valid_indexed_prune_scopes_for_units[i])
    print("cur_prune_scopes", len(cur_prune_scopes), cur_prune_scopes)
    print("cur_kept_percentage", len(cur_kept_percentage), cur_kept_percentage)

    # prepare for training with the specific config
    pre_prune_info = indexed_prune_scopes_to_prune_info(
        pre_prune_scopes, pre_kept_percentage)
    cur_prune_info = indexed_prune_scopes_to_prune_info(
        cur_prune_scopes, cur_kept_percentage)
    print("pre_prune_info")
    pprint(pre_prune_info)
    print("cur_prune_info")
    pprint(cur_prune_info)

    # prepare file system
    results_dir = os.path.join(FLAGS.train_dir,
                               "_".join(map(str, fake_config)))
    train_dir = os.path.join(results_dir,
                             'train_lr' + str(FLAGS.learning_rate))
    print('train_dir:', train_dir)

    # Decide once whether this is a fresh run or a resumed one. The setup
    # below defines variables_init_value/reinit_scopes, and the session code
    # later uses them, so both must take the same branch.
    restart_training = ((not FLAGS.continue_training)
                        or (not tf.train.latest_checkpoint(train_dir)))

    if restart_training:
        prune_scopes = indexed_prune_scopes_to_prune_scopes(
            new_prune_scopes, net_name_scope_checkpoint)
        shorten_scopes = indexed_prune_scopes_to_shorten_scopes(
            new_prune_scopes, net_name_scope_checkpoint)
        print("prune_scopes", len(prune_scopes),
              "\n\t" + "\n\t".join(prune_scopes))
        print("shorten_scopes", len(shorten_scopes),
              "\n\t" + "\n\t".join(shorten_scopes))
        # Initial values for the layers pruned at this stage. They are computed
        # from the network described by the previous configuration.
        variables_init_value = get_init_values_for_pruned_layers(
            prune_scopes, shorten_scopes, new_kept_percentage, pre_prune_info)
        # Same scopes, renamed into the pruned network's name scope.
        reinit_scopes = [
            re.sub(net_name_scope_checkpoint, net_name_scope_pruned, v)
            for v in prune_scopes + shorten_scopes
        ]
        prepare_file_system(train_dir)

    def write_detailed_info(info):
        """Append one line to <train_dir>/train_details.txt."""
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'learning_rate:' + str(FLAGS.learning_rate) + '\n'
    info += 'stage_index:' + str(FLAGS.stage_index) + '\n'
    info += 'configuration: ' + str(cur_config) + '\n'
    info += 'cur_prune_scopes: ' + str(cur_prune_scopes) + '\n'
    info += 'kept_percentage: ' + str(cur_kept_percentage)
    print(info)
    write_detailed_info(info)

    # write current config into file
    with open(os.path.join(train_dir, "config.txt"), 'w') as f:
        f.write(", ".join(map(str, cur_config)))

    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)
        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network #
        ######################
        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            prune_info=cur_prune_info,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        ####################
        # Define the model #
        ####################
        # The eval network is built with reuse_variables=True, presumably so it
        # shares the training network's variables.
        logits_train, _ = network_fn_pruned(images,
                                            is_training=True,
                                            is_local_train=False,
                                            reuse_variables=False,
                                            scope=net_name_scope_pruned)
        logits_eval, _ = network_fn_pruned(test_images,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=True,
                                           scope=net_name_scope_pruned)
        cross_entropy = add_cross_entropy(logits_train, labels)
        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        #############################
        # Specify the loss function #
        #############################
        tf.add_to_collection('subgraph_losses', cross_entropy)
        # get regularization loss
        regularization_losses = get_regularization_losses_within_scopes()
        print_list('regularization_losses', regularization_losses)
        # total loss and its summary
        total_loss = tf.add_n(tf.get_collection('subgraph_losses'),
                              name='total_loss')
        for l in tf.get_collection('subgraph_losses') + [total_loss]:
            tf.summary.scalar(l.op.name + '/summary', l)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        #############################
        # Add train operation       #
        #############################
        variables_to_train = get_trainable_variables_within_scopes()
        train_op = add_train_op(optimizer,
                                total_loss,
                                global_step,
                                var_list=variables_to_train)
        print_list("variables_to_train", variables_to_train)

        # Gather update_ops: the updates for the batch_norm variables created
        # by network_fn_pruned.
        update_ops = get_update_ops_within_scopes()
        print_list("update_ops", update_ops)

        # add train_tensor: evaluating it runs the update ops and the train op.
        update_ops.append(train_op)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            ###########################
            # Prepare for filewriter. #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            # if restart the training or there is no checkpoint in the train_dir
            if restart_training:
                #########################################
                # Reinit pruned model variable          #
                #########################################
                variables_to_reinit = get_model_variables_within_scopes(
                    reinit_scopes)
                print_list("Initialize pruned variables", variables_to_reinit)
                assign_ops = []
                for v in variables_to_reinit:
                    # Map the pruned-scope name back to the checkpoint scope
                    # to find the precomputed initial value.
                    key = re.sub(net_name_scope_pruned,
                                 net_name_scope_checkpoint, v.op.name)
                    if key not in variables_init_value:
                        raise ValueError(
                            "Key not in variables_init_value, key=%s" % key)
                    value = variables_init_value.get(key)
                    assign_ops.append(
                        tf.assign(v,
                                  tf.convert_to_tensor(value),
                                  validate_shape=True))
                sess.run(tf.group(*assign_ops))

                #################################################
                # Restore unchanged model variable.             #
                #################################################
                variables_to_restore = {
                    re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                           v.op.name): v
                    for v in get_model_variables_within_scopes()
                    if v not in variables_to_reinit
                }
                print_list("restore model variables",
                           variables_to_restore.values())
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)
            else:
                ###########################################
                ## Restore all variables from checkpoint ##
                ###########################################
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir,
                                var_list=variables_to_restore)

            #################################################
            # init unitialized global variable.             #
            #################################################
            variables_to_init = get_global_variables_within_scopes(
                sess.run(tf.report_uninitialized_variables()))
            print_list("init unitialized variables", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('\nExit: init_global_step_value (%d) >= FLAG.max_number_of_steps (%d)\n\n'
                      % (init_global_step_value, FLAGS.max_number_of_steps))
                return

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
            print('HG: # of threads=', len(threads))

            duration = 0  # time spent on train-only steps since the last log
            duration_cnt = 0  # train-only steps since the last log
            train_time = 0  # total time spent on train-only steps
            train_only_cnt = 0  # total number of train-only steps

            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):
                if i > init_global_step_value:
                    if (i % FLAGS.runmeta_every_n_steps ==
                            FLAGS.runmeta_every_n_steps - 1):
                        # train while run metadata, and dump a Chrome timeline
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()
                        loss_value = sess.run(train_tensor,
                                              options=run_options,
                                              run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format()
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)
                    elif i % FLAGS.summary_every_n_steps == 0:
                        # train while record summary
                        train_summary, loss_value = sess.run(
                            [summary_op, train_tensor])
                        train_writer.add_summary(train_summary, i)
                    else:
                        # train only; only these steps are timed
                        start_time = time.time()
                        loss_value = sess.run(train_tensor)
                        train_only_cnt += 1
                        train_time += time.time() - start_time
                        duration_cnt += 1
                        duration += time.time() - start_time

                    # log loss information
                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, i)
                        format_str = (
                            '%s: step %d, loss = %.3f (%.1f examples/sec; %.3f sec/batch)'
                        )
                        print(format_str % (datetime.now(), i, loss_value,
                                            examples_per_sec, sec_per_batch))
                        duration = 0
                        duration_cnt = 0
                        info = format_str % (datetime.now(), i, loss_value,
                                             examples_per_sec, sec_per_batch)
                        write_detailed_info(info)
                else:
                    # first iteration: evaluate the loss without training
                    train_summary, loss_value = sess.run(
                        [summary_op, total_loss])
                    train_writer.add_summary(train_summary, i)
                    format_str = ('%s: step %d, loss = %.3f')
                    print(format_str % (datetime.now(), i, loss_value))
                    info = format_str % (datetime.now(), i, loss_value)
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:
                    test_accuracy, run_metadata = evaluate_accuracy(
                        sess,
                        coord,
                        test_dataset.num_samples,
                        test_images,
                        test_labels,
                        test_images,
                        test_labels,
                        correct_prediction,
                        FLAGS.test_batch_size,
                        run_meta=False)
                    summary = tf.Summary()
                    summary.value.add(tag='accuracy',
                                      simple_value=test_accuracy)
                    train_writer.add_summary(summary, i)
                    info = ('%s: step %d, test_accuracy = %.6f') % (
                        datetime.now(), i, test_accuracy)
                    print(info)
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters . #
                    ###########################
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

                    # Stop early once the accuracy target is reached.
                    if test_accuracy >= FLAGS.baseline_accuracy - FLAGS.drop_rate:
                        print(
                            '\nStop training becuase accuracy threshold (%f) reached!\n\n'
                            % (FLAGS.baseline_accuracy - FLAGS.drop_rate))
                        write_detailed_info(
                            'Stop training becuase accuracy threshold (%f) reached!'
                            % (FLAGS.baseline_accuracy - FLAGS.drop_rate))
                        break

            coord.request_stop()
            coord.join(threads)

            total_time = time.time() - tic
            # train_only_cnt is 0 when no plain training step ran (e.g. the
            # accuracy threshold is met at the first evaluation).
            if train_only_cnt:
                train_speed = train_time * 1.0 / train_only_cnt
            else:
                train_speed = 0.0
            # Extrapolated time for max_number_of_steps steps at that speed.
            train_time = train_speed * (FLAGS.max_number_of_steps)
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            write_detailed_info(info)
def main(_):
    """Train a block-wise pruned network by local output reconstruction.

    Setup:
    * ``valid_layer_names`` is split into blocks of ``FLAGS.block_size``
      layers.
    * For every block, the layers to prune (including the layer feeding the
      block, except for the first block) and their kept percentages are
      collected into ``prune_info``. Only a single kept-percentage option is
      supported.

    Graph:
    * The original network is built for reference.
    * The pruned network is built with per-block inputs taken from the
      original network, with the pruned channels dropped.
    * Each block gets its own loss: an L2 loss between the original block
      output (with pruned channels dropped, except for the last block) and
      the pruned block output, plus that block's regularization losses. Each
      block has its own train op restricted to its variables.

    Training:
    * All per-block train ops run every step.
    * Accuracy is evaluated and a checkpoint saved every
      ``FLAGS.evaluate_every_n_steps`` steps and at the last step.
    """
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # initialize constants
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    kp_options = sorted([float(x) for x in FLAGS.kept_percentages.split(',')])
    num_blocks = int(len(valid_layer_names) / FLAGS.block_size)
    if len(valid_layer_names) % FLAGS.block_size != 0:
        print('ERROR: len(valid_layer_names)%FLAGS.block_size!=0')
        return

    print('HG: kp_options', kp_options)
    print('HG: block size:', FLAGS.block_size)
    print('HG: number of blocks:', num_blocks)
    if len(kp_options) > 1:
        print('ERROR: only support kp options = 1')
        return

    # prepare file system
    if FLAGS.last_conv_pruned:
        foldername = 'last_conv_pruned'
    else:
        foldername = 'last_conv_unpruned'
    results_dir = os.path.join(FLAGS.train_dir, foldername,
                               'kp' + str(FLAGS.kept_percentages))
    train_dir = os.path.join(results_dir, 'train')
    print('HG: train_dir:', train_dir)

    # Decide once whether this is a fresh run or a resumed one, so the file
    # system setup and the variable initialization below always agree.
    restart_training = not (FLAGS.continue_training
                            and tf.train.latest_checkpoint(train_dir))
    if restart_training:
        prepare_file_system(train_dir)

    def write_log_info(info):
        """Append text to the shared <FLAGS.train_dir>/log.txt."""
        with open(os.path.join(FLAGS.train_dir, 'log.txt'), 'a') as f:
            f.write(info + '\n')

    def write_detailed_info(info):
        """Append one line to <train_dir>/train_details.txt."""
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'kp_options:' + str(kp_options) + '\n'
    log_info = info + '\n'
    write_detailed_info(info)

    # set prune info
    prune_info = {}
    block_layer_names_list = []
    pruned_layer_names_list = []
    prune_scopes = []
    shorten_scopes = []
    network_config = []  # kept percentage for each entry of prune_scopes
    for block_id in xrange(num_blocks):
        #------------------------
        # get block layer names
        #------------------------
        start_layer_id = FLAGS.block_size * block_id
        end_layer_id = start_layer_id + FLAGS.block_size
        block_layer_names = valid_layer_names[start_layer_id:end_layer_id]
        is_last_block = valid_layer_names[-1] in block_layer_names
        block_layer_names_list.append(block_layer_names)

        #------------------------
        # get pruned layer names so that the pruned block can fit into the
        # test network.
        #------------------------
        pruned_layer_names = valid_layer_names[start_layer_id:end_layer_id]
        config_length = FLAGS.block_size
        # note that last layer cannot be pruned.
        if is_last_block:
            pruned_layer_names.remove(valid_layer_names[-1])
            config_length -= 1
        # if the block is not the first block, prune also the layer before
        # the block
        if block_id != 0:
            pruned_layer_names = [valid_layer_names[start_layer_id - 1]
                                  ] + pruned_layer_names
            config_length += 1
        pruned_layer_names_list.append(pruned_layer_names)

        #------------------------
        # get the pruning configuration for the block
        #------------------------
        # given N=#options, m=block_size. first block has #configs=N^m, other
        # blocks have #configs=N^{m+1}.
        block_configurations = list(
            itertools.product(kp_options, repeat=config_length))
        if FLAGS.block_config_id >= len(block_configurations):
            print(
                'ERROR: block_config_id=%d should be smaller than number of block variants=%d'
                % (FLAGS.block_config_id, len(block_configurations)))
            return
        block_config = list(block_configurations[FLAGS.block_config_id])
        if not FLAGS.last_conv_pruned:
            # fix the input layer to the block to be 1.0
            if block_id != 0:
                block_config[0] = 1.0
            # fix the last conv layer in the block to be 1.0
            if not is_last_block:
                block_config[-1] = 1.0

        #------------------------
        # prepare prune_info with the config
        #------------------------
        for i in xrange(len(block_config)):
            layer_name = pruned_layer_names[i]
            # A layer shared with the previous block must get the same kp.
            if layer_name in prune_info:
                if prune_info[layer_name]['kp'] != block_config[i]:
                    print(
                        "ERROR: layer name already in prune info but prune_info[layer_name]['kp'] != block_config[i]"
                    )
                    return
            prune_info[layer_name] = {'kp': block_config[i]}

        #------------------------
        # get pruned scopes and shorten scopes (each layer only once, with
        # the three lists kept aligned)
        #------------------------
        for layer_name, layer_config in zip(pruned_layer_names, block_config):
            prune_scope = layer_name_to_prune_scope(layer_name,
                                                    net_name_scope_checkpoint)
            if prune_scope not in prune_scopes:
                prune_scopes.append(prune_scope)
                shorten_scopes.append(
                    layer_name_to_shorten_scope(layer_name,
                                                net_name_scope_checkpoint))
                network_config.append(layer_config)

    #------------------------
    # get variables init value and pruned filter indexes
    #------------------------
    print('HG: prune_scopes', len(prune_scopes), prune_scopes)
    print('HG: network_config', network_config)
    # NOTE(review): this call expects a 3-argument helper returning
    # (init values, pruned filter indexes). The
    # get_init_values_for_pruned_layers(prune_scopes, shorten_scopes,
    # kept_percentage, prune_info) defined in this file takes 4 arguments and
    # returns one value. The file looks like several scripts concatenated;
    # confirm which helper is in scope here.
    variables_init_value, pruned_filter_indexes = get_init_values_for_pruned_layers(
        prune_scopes, shorten_scopes, network_config)

    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)
        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network #
        ######################
        # Original (unpruned) network: its end points provide block inputs
        # and reconstruction targets.
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        _, end_points = network_fn(images, is_training=False)

        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        # --------------------------
        # set inputs for prune_info
        # --------------------------
        print_list('pruned_layer_names_list', pruned_layer_names_list)
        for block_id in xrange(num_blocks):
            block_layer_names = block_layer_names_list[block_id]
            pruned_layer_names = pruned_layer_names_list[block_id]
            if block_id == 0:
                block_inputs = images
            else:
                # The original inputs may have more channels than the pruned
                # block expects, so drop the channels listed in
                # pruned_filter_indexes.
                with tf.name_scope('inputs_selector_' + str(block_id)):
                    # get pruned filter indexes
                    prune_scope = net_name_scope_checkpoint + '/' + pruned_layer_names[
                        0]
                    filter_indexes = pruned_filter_indexes[prune_scope]
                    # get original inputs
                    inputs_layer_id = all_layer_names.index(
                        block_layer_names[0]) - 1
                    inputs_scope = net_name_scope_checkpoint + '/' + all_layer_names[
                        inputs_layer_id]
                    original_inputs = end_points[inputs_scope]
                    # downsample original inputs along the last axis
                    num_dim = original_inputs.get_shape().as_list()[-1]
                    block_inputs = tf.stack([
                        original_inputs[:, :, :, i]
                        for i in xrange(num_dim)
                        if i not in filter_indexes
                    ], axis=-1)
                    print('HG: inputs_scope:', inputs_scope,
                          original_inputs.get_shape().as_list(), ' to ',
                          block_inputs.get_shape().as_list())
            # set inputs for this block
            prune_info[block_layer_names[0]]['inputs'] = block_inputs

        print('HG: prune_info:')
        pprint(prune_info)

        # generate the pruned network for training
        _, end_points_pruned = network_fn_pruned(images,
                                                 prune_info=prune_info,
                                                 is_training=True,
                                                 is_local_train=True,
                                                 reuse_variables=False,
                                                 scope=net_name_scope_pruned)
        # generate the pruned network for testing
        logits, _ = network_fn_pruned(test_images,
                                      prune_info=prune_info,
                                      is_training=False,
                                      is_local_train=False,
                                      reuse_variables=True,
                                      scope=net_name_scope_pruned)
        # add correct prediction to the testing network
        correct_prediction = add_correct_prediction(logits, test_labels)

        ##############################
        # Specify the loss functions #
        ##############################
        total_losses = []
        train_tensors = []
        for block_id in xrange(num_blocks):
            print('\n\nblock_id', block_id)
            block_layer_names = block_layer_names_list[block_id]
            pruned_layer_names = pruned_layer_names_list[block_id]
            is_last_block = (block_id == num_blocks - 1)
            outputs_scope = net_name_scope_checkpoint + '/' + block_layer_names[
                -1]
            outputs_scope_pruned = net_name_scope_pruned + '/' + block_layer_names[
                -1]
            print('HG: outputs_scope:', outputs_scope)

            # add reconstruction loss
            collection_name = 'subgraph_losses_' + str(block_id)
            if outputs_scope not in end_points:
                raise ValueError(
                    'end_points does not contain the outputs_scope: %s' %
                    outputs_scope)
            outputs = end_points[outputs_scope]
            if outputs_scope_pruned not in end_points_pruned:
                raise ValueError(
                    'end_points_pruned does not contain the outputs_scope_pruned: %s'
                    % outputs_scope_pruned)
            outputs_pruned = end_points_pruned[outputs_scope_pruned]
            # The original and pruned outputs differ in channel count, so the
            # original output has its pruned channels dropped first (except
            # for the last block).
            if is_last_block:
                outputs_gnd = outputs
            else:
                with tf.name_scope('output_selector_' + str(block_id)):
                    filter_indexes = pruned_filter_indexes[outputs_scope]
                    num_dim = outputs.get_shape().as_list()[-1]
                    outputs_gnd = tf.stack([
                        outputs[:, :, :, i]
                        for i in xrange(num_dim)
                        if i not in filter_indexes
                    ], axis=-1)
                    print('HG: ouputs selector:',
                          outputs.get_shape().as_list(), ' to ',
                          outputs_gnd.get_shape().as_list())
            # Registered in collection_name; the return value is not needed.
            add_l2_loss(outputs_gnd,
                        outputs_pruned,
                        add_to_collection=True,
                        collection_name=collection_name)

            # get regularization loss
            train_scopes = [
                net_name_scope_pruned + '/' + item
                for item in block_layer_names
            ]
            print_list('train_scopes', train_scopes)
            regularization_losses = get_regularization_losses_within_scopes(
                train_scopes,
                add_to_collection=True,
                collection_name=collection_name)
            print_list('regularization_losses', regularization_losses)

            # total loss and its summary
            total_loss = tf.add_n(tf.get_collection(collection_name),
                                  name='total_loss')
            for l in tf.get_collection(collection_name) + [total_loss]:
                tf.summary.scalar(l.op.name + '/summary', l)
            total_losses.append(total_loss)

            #############################
            # Add train operation       #
            #############################
            variables_to_train = get_trainable_variables_within_scopes(
                train_scopes)
            print_list("variables_to_train", variables_to_train)
            train_op = add_train_op(optimizer,
                                    total_loss,
                                    global_step,
                                    var_list=variables_to_train)
            with tf.control_dependencies([train_op]):
                train_tensor = tf.identity(total_loss, name='train_op')
            train_tensors.append(train_tensor)

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))
        print_list(
            'model_variables but not trainable variables',
            list(set(tf.model_variables()).difference(
                tf.trainable_variables())))
        print_list(
            'global_variables but not model variables',
            list(set(tf.global_variables()).difference(tf.model_variables())))
        print(
            "HG: trainable_variables from " + net_name_scope_checkpoint + "=",
            len(get_trainable_variables_within_scopes(
                [net_name_scope_checkpoint + '/'])))
        print(
            "HG: trainable_variables from " + net_name_scope_pruned + "=",
            len(get_trainable_variables_within_scopes(
                [net_name_scope_pruned + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_checkpoint + "=",
            len(get_model_variables_within_scopes(
                [net_name_scope_checkpoint + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_pruned + "=",
            len(get_model_variables_within_scopes(
                [net_name_scope_pruned + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_checkpoint + "=",
            len(get_global_variables_within_scopes(
                [net_name_scope_checkpoint + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_pruned + "=",
            len(get_global_variables_within_scopes(
                [net_name_scope_pruned + '/'])))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            ###########################
            # prepare for filewritter #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            if restart_training:
                ###########################################
                # Restore original model variable values. #
                ###########################################
                variables_to_restore = get_model_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])
                print_list("restore model variables for original",
                           variables_to_restore)
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)

                #################################################
                # Init pruned networks with well-trained model #
                #################################################
                variables_to_reinit = get_model_variables_within_scopes(
                    [net_name_scope_pruned + '/'])
                print_list("init pruned model variables for pruned network",
                           variables_to_reinit)
                assign_ops = []
                for v in variables_to_reinit:
                    # Map the pruned-scope name back to the checkpoint scope
                    # to find the precomputed initial value.
                    key = re.sub(net_name_scope_pruned,
                                 net_name_scope_checkpoint, v.op.name)
                    if key not in variables_init_value:
                        raise ValueError(
                            "Key not in variables_init_value, key=%s" % key)
                    value = variables_init_value.get(key)
                    assign_ops.append(
                        tf.assign(v,
                                  tf.convert_to_tensor(value),
                                  validate_shape=True))
                sess.run(tf.group(*assign_ops))
            else:
                # restore all variables from checkpoint
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir,
                                var_list=variables_to_restore)

            #################################################
            # init unitialized global variable.             #
            #################################################
            variables_to_init = get_global_variables_within_scopes(
                sess.run(tf.report_uninitialized_variables()))
            print_list("init uninitialized_variables", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('Exit: init_global_step_value (%d) >= FLAGS.max_number_of_steps (%d)'
                      % (init_global_step_value, FLAGS.max_number_of_steps))
                return

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('HG: # of threads=', len(threads))

            # saver for models; a non-positive flag keeps enough checkpoints
            # for roughly twice the number of evaluations.
            if FLAGS.max_to_keep <= 0:
                max_to_keep = int(2 * FLAGS.max_number_of_steps /
                                  FLAGS.evaluate_every_n_steps)
            else:
                max_to_keep = FLAGS.max_to_keep
            saver = tf.train.Saver(max_to_keep=max_to_keep)

            train_time = 0  # the amount of time spending on sgd training only.
            duration = 0  # used to estimate the training speed
            train_only_cnt = 0  # used to calculate the true training time.
            duration_cnt = 0

            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):
                if i > init_global_step_value:
                    if (i % FLAGS.runmeta_every_n_steps ==
                            FLAGS.runmeta_every_n_steps - 1):
                        # run metadata, and dump a Chrome timeline
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()
                        loss_values = sess.run(train_tensors,
                                               options=run_options,
                                               run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format()
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)
                    elif i % FLAGS.summary_every_n_steps == 0:
                        # record summary; keep every block's loss, as the
                        # other branches do.
                        results = sess.run([summary_op] + train_tensors)
                        train_summary, loss_values = results[0], results[1:]
                        train_writer.add_summary(train_summary, i)
                    else:
                        # only run train op; only these steps are timed
                        start_time = time.time()
                        loss_values = sess.run(train_tensors)
                        train_only_cnt += 1
                        duration_cnt += 1
                        train_time += time.time() - start_time
                        duration += time.time() - start_time

                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        # record speed
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, i)
                        info = (
                            '%s: step %d, loss = %s (%.1f examples/sec; %.3f sec/batch)'
                        ) % (datetime.now(), i, str(loss_values),
                             examples_per_sec, sec_per_batch)
                        print(info)
                        duration = 0
                        duration_cnt = 0
                        write_detailed_info(info)
                else:
                    # run only total loss when i=0
                    results = sess.run([summary_op] + total_losses)
                    train_summary, loss_values = results[0], results[1:]
                    train_writer.add_summary(train_summary, i)
                    format_str = ('%s: step %d, loss = %s')
                    print(format_str % (datetime.now(), i, str(loss_values)))
                    info = format_str % (datetime.now(), i, str(loss_values))
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:
                    test_accuracy, run_metadata = evaluate_accuracy(
                        sess,
                        coord,
                        test_dataset.num_samples,
                        test_images,
                        test_labels,
                        test_images,
                        test_labels,
                        correct_prediction,
                        FLAGS.test_batch_size,
                        run_meta=False)
                    summary = tf.Summary()
                    summary.value.add(tag='accuracy',
                                      simple_value=test_accuracy)
                    train_writer.add_summary(summary, i)
                    info = ('%s: step %d, test_accuracy = %s') % (
                        datetime.now(), i, str(test_accuracy))
                    print(info)
                    # Only the first and last accuracies go to the shared log.
                    if i == init_global_step_value or is_last_step:
                        log_info += info + '\n'
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters . #
                    ###########################
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

            coord.request_stop()
            coord.join(threads)

            total_time = time.time() - tic
            # train_only_cnt is 0 when every step was a summary/metadata step;
            # avoid dividing by zero.
            if train_only_cnt:
                train_speed = train_time * 1.0 / train_only_cnt
            else:
                train_speed = 0.0
            # Extrapolated time for max_number_of_steps steps at that speed.
            train_time = train_speed * (FLAGS.max_number_of_steps)
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            log_info += info + '\n\n'
            write_log_info(log_info)
            write_detailed_info(info)
def main(_):
    """Fine-tune one pruned configuration selected by the MPI rank.

    The candidate configurations come from ``FLAGS.configuration_type``
    ('sample' or 'special'); the configuration at index
    ``FLAGS.start_configuration_index + rank`` is trained. Pruned blocks are
    initialised from the per-kept-percentage checkpoints found under
    ``FLAGS.checkpoint_path`` (unless training is being continued from
    ``train_dir``), then the whole pruned network is trained up to
    ``FLAGS.max_number_of_steps``, evaluating and saving a checkpoint every
    ``FLAGS.evaluate_every_n_steps`` steps and at the last step.

    Raises:
        ValueError: if ``--dataset_dir`` is missing, a requested kept
            percentage has no pre-trained folder, or ``configuration_type``
            is unknown.
    """
    tic = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # init
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    block_names = valid_block_names
    kept_percentages_dict = get_kept_percentages_dict_from_path(
        FLAGS.checkpoint_path)
    kept_percentages = sorted(map(float, FLAGS.kept_percentages.split(',')))

    # check networks with the kps are pre-trained.
    for kp in kept_percentages:
        if kp not in kept_percentages_dict:
            # `Error` is not a Python builtin (raising it would be a
            # NameError); ValueError carries the intended message.
            raise ValueError('kept_percentage=' + str(kp) +
                             ' not in folder:' + FLAGS.checkpoint_path)

    num_options = len(kept_percentages)
    num_units = len(block_names)
    print('num_options=%d, num_blocks=%d' % (num_options, num_units))
    print('HG: total number of configurations=%d' % (num_options**num_units))
    if FLAGS.configuration_type == 'sample':
        configs = get_sampled_configurations(num_units, num_options,
                                             FLAGS.total_num_configurations)
    elif FLAGS.configuration_type == 'special':
        configs = get_special_configurations(num_units, num_options)
    else:
        # Without this branch `configs` would be unbound below.
        raise ValueError('Unknown configuration_type: %s' %
                         FLAGS.configuration_type)
    num_configurations = len(configs)

    # Getting MPI rank integer
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank >= num_configurations:
        print("ERROR: rank(%d) > num_configurations(%d)" %
              (rank, num_configurations))
        return
    FLAGS.configuration_index = FLAGS.start_configuration_index + rank
    # The start offset shifts the rank, so the rank check above does not
    # guarantee a valid index into `configs`.
    if FLAGS.configuration_index >= num_configurations:
        print('ERROR: configuration_index(%d) >= num_configurations(%d)' %
              (FLAGS.configuration_index, num_configurations))
        return
    config = configs[FLAGS.configuration_index]
    print('HG: kept_percentages=%s, start_config_index=%d, num_configs=%d, rank=%d, config_index=%d'
          % (str(kept_percentages), FLAGS.start_configuration_index,
             num_configurations, rank, FLAGS.configuration_index))

    # prepare for training with the specific config
    kept_percentage = config_to_kept_percentage_sequence(
        config, block_names, kept_percentages)
    prune_info = kept_percentage_sequence_to_prune_info(
        kept_percentage, block_names)
    print('HG: prune_info:')
    pprint(prune_info)

    # prepare file system
    results_dir = os.path.join(FLAGS.train_dir,
                               "id" + str(FLAGS.configuration_index))
    train_dir = os.path.join(results_dir, 'train')
    if (not FLAGS.continue_training) or (
            not tf.train.latest_checkpoint(train_dir)):
        print('Start a new training')
        prepare_file_system(train_dir)
    else:
        print('Continue training')

    def write_detailed_info(info):
        """Append one line of text to <train_dir>/train_details.txt."""
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir: ' + train_dir + '\n'
    info += 'options:' + str(kept_percentages) + '\n'
    info += 'configuration: ' + str(config) + '\n'
    info += 'kept_percentage: ' + str(kept_percentage)
    print(info)
    write_detailed_info(info)

    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network #
        ######################
        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        ####################
        # Define the model #
        ####################
        logits_train, _ = network_fn_pruned(images,
                                            prune_info=prune_info,
                                            is_training=True,
                                            is_local_train=False,
                                            reuse_variables=False,
                                            scope=net_name_scope_pruned)
        # The eval tower reuses the training variables.
        logits_eval, _ = network_fn_pruned(test_images,
                                           prune_info=prune_info,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=True,
                                           scope=net_name_scope_pruned)
        cross_entropy = add_cross_entropy(logits_train, labels)
        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        ##############################
        # Specify the loss functions #
        ##############################
        collection_name = 'subgraph_losses'
        tf.add_to_collection(collection_name, cross_entropy)
        # get regularization loss
        # NOTE(review): these are only printed here; whether they end up in
        # `collection_name` depends on the helper — confirm.
        regularization_losses = get_regularization_losses_within_scopes()
        print_list('regularization_losses', regularization_losses)

        # total loss and its summary
        total_loss = tf.add_n(tf.get_collection(collection_name),
                              name='total_loss')
        for l in tf.get_collection(collection_name) + [total_loss]:
            tf.summary.scalar(l.op.name + '/summary', l)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')

        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        #######################
        # Add train operation #
        #######################
        variables_to_train = get_trainable_variables_within_scopes()
        train_op = add_train_op(optimizer,
                                total_loss,
                                global_step,
                                var_list=variables_to_train)
        print_list("variables_to_train", variables_to_train)

        # Gather update_ops: the updates for the batch_norm variables created
        # by network_fn_pruned.
        update_ops = get_update_ops_within_scopes()
        print_list("update_ops", update_ops)

        update_ops.append(train_op)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))

        # Group the block names by the kept percentage assigned to them, so
        # each group can be restored from the matching pre-trained checkpoint.
        block_names_dict = {}
        for block_name, block_kept_percentage in zip(block_names,
                                                     kept_percentage):
            block_names_dict.setdefault(block_kept_percentage,
                                        []).append(block_name)
        print('HG: block_names_dict:')
        pprint(block_names_dict)

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            ###########################
            # prepare for filewritter #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            # if restart the training or there is no checkpoint in the train_dir
            if (not FLAGS.continue_training) or (
                    not tf.train.latest_checkpoint(train_dir)):
                #########################################
                # Restore pruned model variable values. #
                #########################################
                all_variables_to_train = []
                for block_kept_percentage, kp_block_names in block_names_dict.items():
                    print('HG: kept_percentage', block_kept_percentage)
                    checkpoint_path = os.path.join(
                        FLAGS.checkpoint_path,
                        kept_percentages_dict[block_kept_percentage][0],
                        'train')
                    # Map checkpoint names (scope suffixed with "_p<kp>") to
                    # this graph's pruned-network variables.
                    variables_to_train = {
                        re.sub(net_name_scope_pruned,
                               net_name_scope_pruned + "_p" +
                               str(block_kept_percentage), v.op.name): v
                        for v in get_model_variables_with_block_names(
                            net_name_scope_pruned, kp_block_names)
                    }
                    print_list("restore pruned model variables",
                               variables_to_train.values())
                    load_checkpoint(sess,
                                    checkpoint_path,
                                    var_list=variables_to_train)
                    all_variables_to_train.extend(variables_to_train.values())

                ###########################################
                # Restore orignal model variable values.  #
                ###########################################
                # Built once instead of once per candidate variable.
                restored_pruned_variables = set(all_variables_to_train)
                variables_to_restore = {
                    re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                           v.op.name): v
                    for v in get_model_variables_within_scopes()
                    if v not in restored_pruned_variables
                }
                print_list("restore original model variables",
                           variables_to_restore.values())
                # NOTE(review): `checkpoint_path` is the one left over from the
                # last loop iteration (a pruned-model checkpoint), presumably
                # because it also stores the original-network variables —
                # confirm.
                load_checkpoint(sess,
                                checkpoint_path,
                                var_list=variables_to_restore)
            else:
                ###########################################
                ## Restore all variables from checkpoint ##
                ###########################################
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir, var_list=variables_to_restore)

            #####################################
            # init unitialized global variable. #
            #####################################
            variables_to_init = get_global_variables_within_scopes(
                sess.run(tf.report_uninitialized_variables()))
            print_list("init unitialized variables", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('Exit: init_global_step_value (%d) >= FLAGS.max_number_of_steps (%d)'
                      % (init_global_step_value, FLAGS.max_number_of_steps))
                return

            ####################
            # Record CPU usage #
            ####################
            mpstat_output_filename = os.path.join(train_dir, "cpu-usage.log")
            os.system("mpstat -P ALL 1 > " + mpstat_output_filename +
                      " 2>&1 &")

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
            print('HG: # of threads=', len(threads))

            duration = 0  # time of timed train-only steps since last log
            duration_cnt = 0  # number of such steps since last log
            train_time = 0  # total time of timed train-only steps
            train_only_cnt = 0  # total number of timed train-only steps
            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):
                train_step = i
                # run optional meta data, or summary, while run train tensor
                if i > init_global_step_value:
                    # run metadata and train
                    if i % FLAGS.runmeta_every_n_steps == FLAGS.runmeta_every_n_steps - 1:
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()
                        loss_value = sess.run(train_tensor,
                                              options=run_options,
                                              run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)
                        # Create the Timeline object, and write it to a json file
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format()
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)
                    # record summary and train
                    elif i % FLAGS.summary_every_n_steps == 0:
                        train_summary, loss_value = sess.run(
                            [summary_op, train_tensor])
                        train_writer.add_summary(train_summary, train_step)
                    # train only (the only steps that are timed)
                    else:
                        start_time = time.time()
                        loss_value = sess.run(train_tensor)
                        elapsed = time.time() - start_time
                        train_only_cnt += 1
                        train_time += elapsed
                        duration_cnt += 1
                        duration += elapsed

                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, train_step)
                        format_str = (
                            '%s: step %d, loss = %.3f (%.1f examples/sec; %.3f sec/batch)'
                        )
                        info = format_str % (datetime.now(), i, loss_value,
                                             examples_per_sec, sec_per_batch)
                        print(info)
                        duration = 0
                        duration_cnt = 0
                        write_detailed_info(info)
                else:
                    # run only total loss when i=0
                    train_summary, loss_value = sess.run(
                        [summary_op, total_loss])
                    train_writer.add_summary(train_summary, train_step)
                    format_str = ('%s: step %d, loss = %.3f')
                    info = format_str % (datetime.now(), i, loss_value)
                    print(info)
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:
                    test_accuracy, _ = evaluate_accuracy(
                        sess, coord, test_dataset.num_samples, test_images,
                        test_labels, test_images, test_labels,
                        correct_prediction, FLAGS.test_batch_size,
                        run_meta=False)
                    summary = tf.Summary()
                    summary.value.add(tag='accuracy',
                                      simple_value=test_accuracy)
                    train_writer.add_summary(summary, train_step)
                    info = ('%s: step %d, test_accuracy = %.6f') % (
                        datetime.now(), train_step, test_accuracy)
                    print(info)
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters . #
                    ###########################
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

            coord.request_stop()
            coord.join(threads)
            total_time = time.time() - tic
            # No train-only step may have run (e.g. resuming close to the
            # last step); avoid ZeroDivisionError.
            train_speed = (train_time / train_only_cnt
                           if train_only_cnt else 0.0)
            train_time = FLAGS.max_number_of_steps * train_speed
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f \n" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            write_detailed_info(info)
def main(_):
    """Train pruned copies of every valid block for all kept percentages at once.

    One pruned network (scope ``<net_name_scope_pruned>_p<kp>``) is built per
    kept percentage in ``FLAGS.kept_percentages``. For each block of each
    pruned network, a separate train op minimises the l2 loss between the
    block's output in the original network and in the pruned network, plus
    that block's regularization losses. On a fresh start, the original
    network is restored from ``FLAGS.checkpoint_path``, and each pruned
    network is initialised from ``get_pruned_kernel_matrix`` plus the
    unchanged original weights. All train ops run every step; test accuracy
    of every pruned network is evaluated and a checkpoint is saved every
    ``FLAGS.evaluate_every_n_steps`` steps and at the last step.

    Raises:
        ValueError: if ``--dataset_dir`` is missing, or a pruned variable has
            no computed initial value.
    """
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # init
    print('HG: Train pruned blocks for all valid layers concurrently')
    block_names = valid_block_names
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    kept_percentages = sorted(
        [float(x) for x in FLAGS.kept_percentages.split(',')])
    print_list('kept_percentages', kept_percentages)

    # prepare file system
    results_dir = os.path.join(FLAGS.train_dir, 'kp' + FLAGS.kept_percentages)
    train_dir = os.path.join(results_dir, 'train')
    if (not FLAGS.continue_training) or (
            not tf.train.latest_checkpoint(train_dir)):
        print('Start a new training')
        prepare_file_system(train_dir)
    else:
        print('Continue training')

    def write_log_info(info):
        """Append one line of text to <FLAGS.train_dir>/log.txt."""
        with open(os.path.join(FLAGS.train_dir, 'log.txt'), 'a') as f:
            f.write(info + '\n')

    def write_detailed_info(info):
        """Append one line of text to <train_dir>/train_details.txt."""
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir
    # Summary text accumulated during the run and written to log.txt at the end.
    log_info = info + '\n'
    write_detailed_info(info)

    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)

        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network #
        ######################
        network_fn = nets_factory.get_network_fn(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        # The original (unpruned) network supplies the per-block target outputs.
        _, end_points = network_fn(images, is_training=False)

        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')

        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        ####################
        # Define the model #
        ####################
        # each kept_percentage corresponds to a pruned network.
        train_tensors = []
        total_losses = []
        pruned_net_name_scopes = []
        correct_predictions = []
        prune_infos = []
        for kept_percentage in kept_percentages:
            prune_info = kept_percentage_sequence_to_prune_info(
                kept_percentage, block_names)
            set_prune_info_inputs(prune_info, end_points)
            prune_infos.append(prune_info)

            # the pruned network scope
            net_name_scope_pruned = FLAGS.net_name_scope_pruned + '_p' + str(
                kept_percentage)
            pruned_net_name_scopes.append(net_name_scope_pruned)

            # generate the pruned network for training
            _, end_points_pruned = network_fn_pruned(
                images,
                prune_info=prune_info,
                is_training=True,
                is_local_train=True,
                reuse_variables=False,
                scope=net_name_scope_pruned)
            # generate the pruned network for testing (reuses the variables)
            logits, _ = network_fn_pruned(test_images,
                                          prune_info=prune_info,
                                          is_training=False,
                                          is_local_train=False,
                                          reuse_variables=True,
                                          scope=net_name_scope_pruned)
            # add correct prediction to the testing network
            correct_predictions.append(
                add_correct_prediction(logits, test_labels))

            ##############################
            # Specify the loss functions #
            ##############################
            for i, block_name in enumerate(block_names):
                print('HG: i=%d, block_name=%s' % (i, block_name))
                # Each (kept_percentage, block) pair gets its own loss
                # collection, train op and (except the first) global step.
                appendix = '_p' + str(kept_percentage) + '_' + str(i)
                collection_name = 'subgraph_losses' + appendix

                # add l2 losses
                outputs = end_points[block_name]
                outputs_pruned = end_points_pruned[block_name]
                add_l2_loss(outputs,
                            outputs_pruned,
                            add_to_collection=True,
                            collection_name=collection_name)
                # get regularization loss
                regularization_losses = get_regularization_losses_with_block_names(
                    net_name_scope_pruned,
                    block_name,
                    add_to_collection=True,
                    collection_name=collection_name)
                print_list('regularization_losses', regularization_losses)

                # total loss and its summary
                total_loss = tf.add_n(tf.get_collection(collection_name),
                                      name='total_loss')
                for l in tf.get_collection(collection_name) + [total_loss]:
                    tf.summary.scalar(l.op.name + appendix + '/summary', l)
                total_losses.append(total_loss)

                #######################
                # Add train operation #
                #######################
                variables_to_train = get_trainable_variables_with_block_names(
                    net_name_scope_pruned, block_name)
                print_list("variables_to_train", variables_to_train)
                # Only the first train op advances the shared global_step;
                # the others get private step counters.
                if i == 0 and kept_percentage == kept_percentages[0]:
                    global_step_tmp = global_step
                else:
                    global_step_tmp = tf.Variable(0,
                                                  trainable=False,
                                                  name='global_step' + appendix)
                train_op = add_train_op(optimizer,
                                        total_loss,
                                        global_step_tmp,
                                        var_list=variables_to_train)

                # Gather update_ops: the updates for the batch_norm variables
                # created by network_fn_pruned.
                update_ops = get_update_ops_with_block_names(
                    net_name_scope_pruned, block_name)
                print_list("update_ops", update_ops)
                update_ops.append(train_op)
                update_op = tf.group(*update_ops)
                with tf.control_dependencies([update_op]):
                    train_tensor = tf.identity(total_loss,
                                               name='train_op' + appendix)
                train_tensors.append(train_tensor)

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))
        print_list(
            'model_variables but not trainable variables',
            list(set(tf.model_variables()).difference(
                tf.trainable_variables())))
        print_list(
            'global_variables but not model variables',
            list(set(tf.global_variables()).difference(tf.model_variables())))
        # The "pruned" statistics below cover only the last pruned scope.
        print(
            "HG: trainable_variables from " + net_name_scope_checkpoint + "=",
            len(get_trainable_variables_within_scopes(
                [net_name_scope_checkpoint + '/'])))
        print(
            "HG: trainable_variables from " + net_name_scope_pruned + "=",
            len(get_trainable_variables_within_scopes(
                [net_name_scope_pruned + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_checkpoint + "=",
            len(get_model_variables_within_scopes(
                [net_name_scope_checkpoint + '/'])))
        print(
            "HG: model_variables from " + net_name_scope_pruned + "=",
            len(get_model_variables_within_scopes(
                [net_name_scope_pruned + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_checkpoint + "=",
            len(get_global_variables_within_scopes(
                [net_name_scope_checkpoint + '/'])))
        print(
            "HG: global_variables from " + net_name_scope_pruned + "=",
            len(get_global_variables_within_scopes(
                [net_name_scope_pruned + '/'])))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)
        with tf.Session(config=sess_config) as sess:
            ###########################
            # prepare for filewritter #
            ###########################
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)

            if (not FLAGS.continue_training) or (
                    not tf.train.latest_checkpoint(train_dir)):
                ###########################################
                # Restore original model variable values. #
                ###########################################
                variables_to_restore = get_model_variables_within_scopes(
                    [net_name_scope_checkpoint + '/'])
                print_list("restore model variables for original",
                           variables_to_restore)
                load_checkpoint(sess,
                                FLAGS.checkpoint_path,
                                var_list=variables_to_restore)

                #################################################
                # Init pruned networks with well-trained model #
                #################################################
                for net_name_scope_pruned, prune_info in zip(
                        pruned_net_name_scopes, prune_infos):
                    print('net_name_scope_pruned=', net_name_scope_pruned)

                    ## init pruned variables .
                    variables_init_value = get_pruned_kernel_matrix(
                        sess, prune_info, net_name_scope_checkpoint)
                    reinit_scopes = [
                        re.sub(net_name_scope_checkpoint,
                               net_name_scope_pruned, name)
                        for name in variables_init_value.keys()
                    ]
                    variables_to_reinit = get_model_variables_within_scopes(
                        reinit_scopes)
                    print_list("Initialize pruned variables",
                               variables_to_reinit)

                    assign_ops = []
                    for v in variables_to_reinit:
                        key = re.sub(net_name_scope_pruned,
                                     net_name_scope_checkpoint, v.op.name)
                        if key not in variables_init_value:
                            raise ValueError(
                                "Key not in variables_init_value, key=%s" %
                                key)
                        value = variables_init_value[key]
                        assign_ops.append(
                            tf.assign(v,
                                      tf.convert_to_tensor(value),
                                      validate_shape=True))
                    sess.run(tf.group(*assign_ops))

                    #####################################
                    # Restore unchanged model variable. #
                    #####################################
                    reinit_set = set(variables_to_reinit)
                    variables_to_restore = {
                        re.sub(net_name_scope_pruned,
                               net_name_scope_checkpoint, v.op.name): v
                        for v in get_model_variables_within_scopes(
                            [net_name_scope_pruned + '/'])
                        if v not in reinit_set
                    }
                    print_list(
                        "restore model variables for " + net_name_scope_pruned,
                        variables_to_restore.values())
                    load_checkpoint(sess,
                                    FLAGS.checkpoint_path,
                                    var_list=variables_to_restore)
            else:
                # restore all variables from checkpoint
                variables_to_restore = get_global_variables_within_scopes()
                load_checkpoint(sess, train_dir, var_list=variables_to_restore)

            #####################################
            # init unitialized global variable. #
            #####################################
            uninitialized_variables = [
                x.decode('utf-8')
                for x in sess.run(tf.report_uninitialized_variables())
            ]
            print_list('uninitialized variables', uninitialized_variables)
            uninitialized_names = set(uninitialized_variables)
            variables_to_init = [
                v for v in tf.global_variables()
                if v.name.split(':')[0] in uninitialized_names
            ]
            print_list("variables_to_init", variables_to_init)
            sess.run(tf.variables_initializer(variables_to_init))

            init_global_step_value = sess.run(global_step)
            print('initial global step: ', init_global_step_value)
            if init_global_step_value >= FLAGS.max_number_of_steps:
                print('Exit: init_global_step_value (%d) >= FLAGS.max_number_of_steps (%d)'
                      % (init_global_step_value, FLAGS.max_number_of_steps))
                return

            ####################
            # Record CPU usage #
            ####################
            mpstat_output_filename = os.path.join(train_dir, "cpu-usage.log")
            os.system("mpstat -P ALL 1 > " + mpstat_output_filename +
                      " 2>&1 &")

            ###########################
            # Kicks off the training. #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('HG: # of threads=', len(threads))

            # saver for models; a non-positive flag keeps roughly two
            # checkpoints per evaluation interval.
            if FLAGS.max_to_keep <= 0:
                max_to_keep = int(2 * FLAGS.max_number_of_steps /
                                  FLAGS.evaluate_every_n_steps)
            else:
                max_to_keep = FLAGS.max_to_keep
            saver = tf.train.Saver(max_to_keep=max_to_keep)

            train_time = 0  # the amount of time spending on sgd training only.
            duration = 0  # used to estimate the training speed
            train_only_cnt = 0  # used to calculate the true training time.
            duration_cnt = 0
            print("start to train at:", datetime.now())
            for i in range(init_global_step_value,
                           FLAGS.max_number_of_steps + 1):
                # run optional meta data, or summary, while run train tensor
                if i > init_global_step_value:
                    # run metadata
                    if i % FLAGS.runmeta_every_n_steps == FLAGS.runmeta_every_n_steps - 1:
                        run_options = tf.RunOptions(
                            trace_level=tf.RunOptions.FULL_TRACE)
                        run_metadata = tf.RunMetadata()
                        loss_values = sess.run(train_tensors,
                                               options=run_options,
                                               run_metadata=run_metadata)
                        train_writer.add_run_metadata(run_metadata,
                                                      'step%d-train' % i)
                        # Create the Timeline object, and write it to a json file
                        fetched_timeline = timeline.Timeline(
                            run_metadata.step_stats)
                        chrome_trace = fetched_timeline.generate_chrome_trace_format()
                        with open(
                                os.path.join(train_dir,
                                             'timeline_' + str(i) + '.json'),
                                'w') as f:
                            f.write(chrome_trace)
                    # record summary
                    elif i % FLAGS.summary_every_n_steps == 0:
                        results = sess.run([summary_op] + train_tensors)
                        train_summary, loss_values = results[0], results[1:]
                        train_writer.add_summary(train_summary, i)
                    # only run train op (the only steps that are timed)
                    else:
                        start_time = time.time()
                        loss_values = sess.run(train_tensors)
                        elapsed = time.time() - start_time
                        train_only_cnt += 1
                        duration_cnt += 1
                        train_time += elapsed
                        duration += elapsed

                    if i % FLAGS.log_every_n_steps == 0 and duration_cnt > 0:
                        # record speed
                        log_frequency = duration_cnt
                        examples_per_sec = log_frequency * FLAGS.batch_size / duration
                        sec_per_batch = float(duration / log_frequency)
                        summary = tf.Summary()
                        summary.value.add(tag='examples_per_sec',
                                          simple_value=examples_per_sec)
                        summary.value.add(tag='sec_per_batch',
                                          simple_value=sec_per_batch)
                        train_writer.add_summary(summary, i)
                        info = (
                            '%s: step %d, loss = %s (%.1f examples/sec; %.3f sec/batch)'
                        ) % (datetime.now(), i, str(loss_values),
                             examples_per_sec, sec_per_batch)
                        print(info)
                        duration = 0
                        duration_cnt = 0
                        write_detailed_info(info)
                else:
                    # run only total loss at the first step
                    results = sess.run([summary_op] + total_losses)
                    train_summary, loss_values = results[0], results[1:]
                    train_writer.add_summary(train_summary, i)
                    format_str = ('%s: step %d, loss = %s')
                    info = format_str % (datetime.now(), i, str(loss_values))
                    print(info)
                    write_detailed_info(info)

                # record the evaluation accuracy
                is_last_step = (i == FLAGS.max_number_of_steps)
                if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:
                    # test accuracy; each kept_percentage corresponds to a
                    # pruned network, and thus an accuracy.
                    test_accuracies = []
                    for kept_percentage, correct_prediction in zip(
                            kept_percentages, correct_predictions):
                        appendix = '_p' + str(kept_percentage)
                        test_accuracy, _ = evaluate_accuracy(
                            sess, coord, test_dataset.num_samples,
                            test_images, test_labels, test_images,
                            test_labels, correct_prediction,
                            FLAGS.test_batch_size, run_meta=False)
                        summary = tf.Summary()
                        summary.value.add(tag='accuracy' + appendix,
                                          simple_value=test_accuracy)
                        train_writer.add_summary(summary, i)
                        test_accuracies.append((kept_percentage, test_accuracy))
                    acc_str = '[' + ', '.join([
                        '(%s, %.6f)' % (str(kp), acc)
                        for kp, acc in test_accuracies
                    ]) + ']'
                    info = ('%s: step %d, test_accuracy = %s') % (
                        datetime.now(), i, str(acc_str))
                    print(info)
                    # Keep the first (including resumed-start) and last
                    # evaluations for log.txt, like the sibling training loop.
                    if i == init_global_step_value or is_last_step:
                        log_info += info + '\n'
                    write_detailed_info(info)

                    ###########################
                    # Save model parameters . #
                    ###########################
                    save_path = saver.save(
                        sess, os.path.join(train_dir, 'model.ckpt-' + str(i)))
                    print("HG: Model saved in file: %s" % save_path)

            coord.request_stop()
            coord.join(threads)
            total_time = time.time() - tic
            # No train-only step may have run (e.g. resuming close to the
            # last step); avoid ZeroDivisionError.
            train_speed = (train_time / train_only_cnt
                           if train_only_cnt else 0.0)
            train_time = FLAGS.max_number_of_steps * train_speed
            info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
            info += "HG: training time(min): %.1f, total time(min): %.1f \n" % (
                train_time / 60.0, total_time / 60.0)
            print(info)
            log_info += info
            write_log_info(log_info)
            write_detailed_info(info)
def main(_):
    """Build one pruned configuration (selected by MPI rank) and record its variables.

    The configuration at index ``FLAGS.start_configuration_index + rank`` is
    taken from the candidates chosen by ``FLAGS.configuration_type``. The
    pruned network is built on the test inputs, and the name and shape of
    every model variable are written to ``<train_dir>/id<index>/eval_details.txt``,
    together with the elapsed time.

    Raises:
        ValueError: if ``--dataset_dir`` is missing or ``configuration_type``
            is unknown.
    """
    tic = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)
    # The test dataset below is read from FLAGS.dataset_dir.
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # init
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    indexed_prune_scopes_for_units = valid_indexed_prune_scopes_for_units
    kept_percentages = sorted(map(float, FLAGS.kept_percentages.split(',')))
    num_options = len(kept_percentages)
    num_units = len(indexed_prune_scopes_for_units)
    print('num_options=%d, num_blocks=%d' % (num_options, num_units))
    print('HG: total number of configurations=%d' % (num_options**num_units))

    # find the configurations to evaluate
    if FLAGS.configuration_type == 'sample':
        configs = get_sampled_configurations(num_units, num_options,
                                             FLAGS.total_num_configurations)
    elif FLAGS.configuration_type == 'special':
        configs = get_special_configurations(num_units, num_options)
    else:
        # Without this branch `configs` would be unbound below.
        raise ValueError('Unknown configuration_type: %s' %
                         FLAGS.configuration_type)
    num_configurations = len(configs)

    # Getting MPI rank integer
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if rank >= num_configurations:
        print("ERROR: rank(%d) > num_configurations(%d)" %
              (rank, num_configurations))
        return
    FLAGS.configuration_index = FLAGS.start_configuration_index + rank
    # Bounds-check BEFORE indexing: checking after `configs[...]` would never
    # run, because an out-of-range index raises IndexError first.
    if FLAGS.configuration_index >= num_configurations:
        print('configuration_index >= num_configurations',
              FLAGS.configuration_index, num_configurations)
        return
    config = configs[FLAGS.configuration_index]
    print('HG: kept_percentages=%s, start_config_index=%d, num_configs=%d, rank=%d, config_index=%d'
          % (str(kept_percentages), FLAGS.start_configuration_index,
             num_configurations, rank, FLAGS.configuration_index))

    # prepare for evaluate the number of parameters with the specific config
    combination = config
    indexed_prune_scopes, kept_percentage = config_to_indexed_prune_scopes(
        combination, indexed_prune_scopes_for_units, kept_percentages)

    # prepare file system
    eval_dir = os.path.join(FLAGS.train_dir,
                            "id" + str(FLAGS.configuration_index))
    prepare_file_system(eval_dir)

    def write_detailed_info(info):
        """Append one line of text to <eval_dir>/eval_details.txt."""
        with open(os.path.join(eval_dir, 'eval_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'eval_dir:' + eval_dir + '\n'
    info += 'options:' + FLAGS.kept_percentages + '\n'
    info += 'combination: ' + str(combination) + '\n'
    info += 'indexed_prune_scopes: ' + str(indexed_prune_scopes) + '\n'
    info += 'kept_percentage: ' + str(kept_percentage)
    print(info)
    write_detailed_info(info)

    with tf.Graph().as_default():
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)

        ######################
        # Select the network #
        ######################
        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            num_classes=(test_dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)

        ####################
        # Define the model #
        ####################
        prune_info = indexed_prune_scopes_to_prune_info(
            indexed_prune_scopes, kept_percentage)
        print('HG: prune_info:')
        pprint(prune_info)

        logits, _ = network_fn_pruned(test_images,
                                      prune_info=prune_info,
                                      is_training=False,
                                      is_local_train=False,
                                      reuse_variables=False,
                                      scope=net_name_scope_pruned)
        # NOTE(review): the result is not used here; the call is kept so the
        # graph matches the original script.
        add_correct_prediction(logits, test_labels)

        # One "<name>\t<shape list>" line per model variable.
        model_variables = get_model_variables_within_scopes()
        name_size_strs = [
            x.op.name + '\t' + str(x.get_shape().as_list())
            for x in model_variables
        ]
        write_detailed_info('\n'.join(name_size_strs))

    total_time = time.time() - tic
    info = 'Evaluate network total_time(s)=%.3f \n' % (total_time)
    print(info)
    write_detailed_info(info)
def main(_):
    """Fine-tune one sampled pruning configuration and evaluate it periodically.

    The configuration is ``configs[FLAGS.start_config_id + rank]`` where
    ``configs`` comes from ``get_sampled_configurations``. Each entry is an
    index into the sorted ``FLAGS.kept_percentages`` options, one per prunable
    unit.

    Behavior depends on whether a checkpoint can be resumed:
      * When ``FLAGS.continue_training`` is set and ``train_dir`` already holds
        a checkpoint, all variables are restored from ``train_dir``.
      * Otherwise the pruned layers are initialized from the values returned by
        ``get_init_values_for_pruned_layers``, and the other model variables
        are restored from ``FLAGS.checkpoint_path``.

    Training then runs up to ``FLAGS.max_number_of_steps``. Loss, throughput
    and test accuracy are logged to TensorBoard and ``train_details.txt``, and
    a checkpoint is saved at every evaluation step.
    """
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # initialize constants
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    kp_options = sorted([float(x) for x in FLAGS.kept_percentages.split(',')])
    num_options = len(kp_options)
    # Every layer in valid_layer_names except the last one is a pruning unit.
    pruned_layer_names = valid_layer_names[:-1]
    num_units = len(pruned_layer_names)
    print_list('kp_options', kp_options)
    print('HG: num_options=%d, num_units=%d' % (num_options, num_units))
    print('HG: total number of configurations=%d' % (num_options**num_units))

    # find the configurations to evaluate
    configs = get_sampled_configurations(num_units, num_options,
                                         FLAGS.total_num_configs)
    num_configs = len(configs)

    # The MPI rank offset is currently disabled: every run uses rank 0.
    rank = 0
    config_id = FLAGS.start_config_id + rank
    print('HG: start_config_index=%d, rank=%d, config_index=%d' %
          (FLAGS.start_config_id, rank, config_id))
    if config_id >= num_configs:
        print("ERROR: config_id(%d) >= num_configs(%d)" %
              (config_id, num_configs))
        return

    # Translate the sampled option indices into kept percentages.
    config = [kp_options[i] for i in configs[config_id]]
    if not FLAGS.last_conv_pruned:
        # The last conv in a block (every FLAGS.block_size-th unit) is kept
        # unpruned.
        config = [
            1.0 if (i + 1) % FLAGS.block_size == 0 else kp
            for i, kp in enumerate(config)
        ]
    print('HG: selected config=', config)

    # prepare for training with the specific config
    prune_info = {
        layer_name: {'kp': kp}
        for layer_name, kp in zip(pruned_layer_names, config)
    }

    # prepare file system
    if FLAGS.last_conv_pruned:
        foldername = 'last_conv_pruned'
    else:
        foldername = 'last_conv_unpruned'
    results_dir = os.path.join(FLAGS.train_dir, foldername,
                               'id' + str(config_id))
    train_dir = os.path.join(results_dir, 'train')

    # Decide once whether to resume. The init values computed here and the
    # restore-vs-reinit branch inside the session then can never disagree.
    resume_training = bool(FLAGS.continue_training and
                           tf.train.latest_checkpoint(train_dir))
    if not resume_training:
        prune_scopes = [
            layer_name_to_prune_scope(layer_name, net_name_scope_checkpoint)
            for layer_name in pruned_layer_names
        ]
        shorten_scopes = [
            layer_name_to_shorten_scope(layer_name, net_name_scope_checkpoint)
            for layer_name in pruned_layer_names
        ]
        variables_init_value, _ = get_init_values_for_pruned_layers(
            prune_scopes, shorten_scopes, config)
        # The same scopes, renamed from the checkpoint name scope to the
        # pruned network's name scope.
        reinit_scopes = [
            re.sub(net_name_scope_checkpoint, net_name_scope_pruned, v)
            for v in prune_scopes + shorten_scopes
        ]
        prepare_file_system(train_dir)

    def write_detailed_info(info):
        with open(os.path.join(train_dir, 'train_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'kp_options:' + str(kp_options) + '\n'
    info += 'configuration: ' + str(config) + '\n'
    write_detailed_info(info)

    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                              FLAGS.train_dataset_name,
                                              FLAGS.dataset_dir)
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)
        batch_queue = train_inputs(dataset, deploy_config, FLAGS)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)
        images, labels = batch_queue.dequeue()

        ######################
        # Select the network #
        ######################
        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            prune_info=prune_info,
            num_classes=(dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        print('HG: prune_info:')
        pprint(prune_info)

        ####################
        # Define the model #
        ####################
        # A training tower plus an evaluation tower that reuses its variables.
        logits_train, _ = network_fn_pruned(images,
                                            is_training=True,
                                            is_local_train=False,
                                            reuse_variables=False,
                                            scope=net_name_scope_pruned)
        logits_eval, _ = network_fn_pruned(test_images,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=True,
                                           scope=net_name_scope_pruned)
        cross_entropy = add_cross_entropy(logits_train, labels)
        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        #############################
        # Specify the loss function #
        #############################
        tf.add_to_collection('subgraph_losses', cross_entropy)
        # get regularization loss
        regularization_losses = get_regularization_losses_within_scopes()
        print_list('regularization_losses', regularization_losses)
        # NOTE(review): total_loss sums only the 'subgraph_losses' collection.
        # Confirm whether the regularization losses above are added to it.
        total_loss = tf.add_n(tf.get_collection('subgraph_losses'),
                              name='total_loss')
        for l in tf.get_collection('subgraph_losses') + [total_loss]:
            tf.summary.scalar(l.op.name + '/summary', l)

        #########################################
        # Configure the optimization procedure. #
        #########################################
        with tf.device(deploy_config.variables_device()):
            global_step = tf.Variable(0, trainable=False, name='global_step')
        with tf.device(deploy_config.optimizer_device()):
            learning_rate = configure_learning_rate(dataset.num_samples,
                                                    global_step, FLAGS)
            optimizer = configure_optimizer(learning_rate, FLAGS)
            tf.summary.scalar('learning_rate', learning_rate)

        #############################
        # Add train operation       #
        #############################
        variables_to_train = get_trainable_variables_within_scopes()
        train_op = add_train_op(optimizer,
                                total_loss,
                                global_step,
                                var_list=variables_to_train)
        print_list("variables_to_train", variables_to_train)

        # Gather update_ops: the updates for the batch_norm variables created
        # by network_fn_pruned.
        update_ops = get_update_ops_within_scopes()
        print_list("update_ops", update_ops)

        # Running train_tensor runs the train op and every update op, then
        # returns the total loss.
        update_ops.append(train_op)
        update_op = tf.group(*update_ops)
        with tf.control_dependencies([update_op]):
            train_tensor = tf.identity(total_loss, name='train_op')

        # add summary op
        summary_op = tf.summary.merge_all()

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)

        with tf.Session(config=sess_config) as sess:
            train_writer = tf.summary.FileWriter(train_dir, sess.graph)
            # Close the writer on every exit path so buffered events are
            # flushed to disk.
            try:
                if resume_training:
                    # Restore all variables from the train_dir checkpoint.
                    variables_to_restore = get_global_variables_within_scopes()
                    load_checkpoint(sess,
                                    train_dir,
                                    var_list=variables_to_restore)
                else:
                    # Reinit the pruned model variables.
                    variables_to_reinit = get_model_variables_within_scopes(
                        reinit_scopes)
                    print_list("Initialize pruned variables",
                               variables_to_reinit)
                    assign_ops = []
                    for v in variables_to_reinit:
                        key = re.sub(net_name_scope_pruned,
                                     net_name_scope_checkpoint, v.op.name)
                        if key not in variables_init_value:
                            raise ValueError(
                                "Key not in variables_init_value, key=%s" %
                                key)
                        assign_ops.append(
                            tf.assign(v,
                                      tf.convert_to_tensor(
                                          variables_init_value[key]),
                                      validate_shape=True))
                    sess.run(tf.group(*assign_ops))

                    # Restore the unchanged model variables, mapping pruned
                    # names back to checkpoint names.
                    variables_to_restore = {
                        re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                               v.op.name): v
                        for v in get_model_variables_within_scopes()
                        if v not in variables_to_reinit
                    }
                    print_list("restore model variables",
                               variables_to_restore.values())
                    load_checkpoint(sess,
                                    FLAGS.checkpoint_path,
                                    var_list=variables_to_restore)

                    # Initialize whatever is still uninitialized (e.g.
                    # variables not covered by the restores above).
                    variables_to_init = get_global_variables_within_scopes(
                        sess.run(tf.report_uninitialized_variables()))
                    print_list("init unitialized variables", variables_to_init)
                    sess.run(tf.variables_initializer(variables_to_init))

                init_global_step_value = sess.run(global_step)
                print('initial global step: ', init_global_step_value)
                if init_global_step_value >= FLAGS.max_number_of_steps:
                    print('Exit: init_global_step_value (%d) >= FLAG.max_number_of_steps (%d)'
                          % (init_global_step_value, FLAGS.max_number_of_steps))
                    return

                ###########################
                # Kicks off the training. #
                ###########################
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                saver = tf.train.Saver(max_to_keep=FLAGS.max_to_keep)
                print('HG: # of threads=', len(threads))

                # duration/duration_cnt: timing window since the last log line.
                # train_time/train_only_cnt: totals over plain training steps.
                duration = 0
                duration_cnt = 0
                train_time = 0
                train_only_cnt = 0

                print("start to train at:", datetime.now())
                try:
                    for i in range(init_global_step_value,
                                   FLAGS.max_number_of_steps + 1):
                        if i > init_global_step_value:
                            if (i % FLAGS.runmeta_every_n_steps ==
                                    FLAGS.runmeta_every_n_steps - 1):
                                # Train while recording run metadata, and
                                # dump a Chrome-trace timeline for this step.
                                run_options = tf.RunOptions(
                                    trace_level=tf.RunOptions.FULL_TRACE)
                                run_metadata = tf.RunMetadata()
                                loss_value = sess.run(
                                    train_tensor,
                                    options=run_options,
                                    run_metadata=run_metadata)
                                train_writer.add_run_metadata(
                                    run_metadata, 'step%d-train' % i)
                                fetched_timeline = timeline.Timeline(
                                    run_metadata.step_stats)
                                chrome_trace = (fetched_timeline.
                                                generate_chrome_trace_format())
                                timeline_path = os.path.join(
                                    train_dir, 'timeline_' + str(i) + '.json')
                                with open(timeline_path, 'w') as f:
                                    f.write(chrome_trace)
                            elif i % FLAGS.summary_every_n_steps == 0:
                                # Train while recording summaries.
                                train_summary, loss_value = sess.run(
                                    [summary_op, train_tensor])
                                train_writer.add_summary(train_summary, i)
                            else:
                                # Train only; these are the timed steps.
                                start_time = time.time()
                                loss_value = sess.run(train_tensor)
                                step_time = time.time() - start_time
                                train_only_cnt += 1
                                train_time += step_time
                                duration_cnt += 1
                                duration += step_time

                            # Log loss and throughput over the timed window.
                            if (i % FLAGS.log_every_n_steps == 0 and
                                    duration_cnt > 0):
                                log_frequency = duration_cnt
                                examples_per_sec = (log_frequency *
                                                    FLAGS.batch_size / duration)
                                sec_per_batch = float(duration / log_frequency)
                                summary = tf.Summary()
                                summary.value.add(tag='examples_per_sec',
                                                  simple_value=examples_per_sec)
                                summary.value.add(tag='sec_per_batch',
                                                  simple_value=sec_per_batch)
                                train_writer.add_summary(summary, i)
                                format_str = (
                                    '%s: step %d, loss = %s (%.1f examples/sec; %.3f sec/batch)'
                                )
                                info = format_str % (datetime.now(), i,
                                                     str(loss_value),
                                                     examples_per_sec,
                                                     sec_per_batch)
                                print(info)
                                write_detailed_info(info)
                                duration = 0
                                duration_cnt = 0
                        else:
                            # First iteration: only evaluate the total loss.
                            train_summary, loss_value = sess.run(
                                [summary_op, total_loss])
                            train_writer.add_summary(train_summary, i)
                            info = '%s: step %d, loss = %s' % (
                                datetime.now(), i, str(loss_value))
                            print(info)
                            write_detailed_info(info)

                        # Record the evaluation accuracy and save a checkpoint.
                        is_last_step = (i == FLAGS.max_number_of_steps)
                        if i % FLAGS.evaluate_every_n_steps == 0 or is_last_step:
                            test_accuracy, _ = evaluate_accuracy(
                                sess,
                                coord,
                                test_dataset.num_samples,
                                test_images,
                                test_labels,
                                test_images,
                                test_labels,
                                correct_prediction,
                                FLAGS.test_batch_size,
                                run_meta=False)
                            summary = tf.Summary()
                            summary.value.add(tag='accuracy',
                                              simple_value=test_accuracy)
                            train_writer.add_summary(summary, i)
                            info = ('%s: step %d, test_accuracy = %.6f') % (
                                datetime.now(), i, test_accuracy)
                            print(info)
                            write_detailed_info(info)

                            save_path = saver.save(
                                sess,
                                os.path.join(train_dir,
                                             'model.ckpt-' + str(i)))
                            print("HG: Model saved in file: %s" % save_path)
                finally:
                    # Always stop the queue-runner threads, even when a step
                    # raised, so the process does not hang on exit.
                    coord.request_stop()
                    coord.join(threads)

                total_time = time.time() - tic
                # Average time over plain training steps. If none ran (e.g. a
                # resumed run whose remaining steps all recorded summaries or
                # traces), report NaN instead of dividing by zero.
                if train_only_cnt > 0:
                    train_speed = train_time * 1.0 / train_only_cnt
                else:
                    train_speed = float('nan')
                # Extrapolated training time for the whole step budget.
                train_time = train_speed * FLAGS.max_number_of_steps
                info = "HG: training speed(sec/batch): %.6f\n" % (train_speed)
                info += "HG: training time(min): %.1f, total time(min): %.1f" % (
                    train_time / 60.0, total_time / 60.0)
                print(info)
                write_detailed_info(info)
            finally:
                train_writer.close()
def main(_):
    """Evaluate the network with a single unit pruned, without fine-tuning.

    Builds a configuration in which only unit ``FLAGS.block_id`` uses
    ``FLAGS.kept_percentage`` and all other units keep 1.0. It then:
      * initializes the pruned layers from ``get_init_values_for_pruned_layers``,
      * restores every other model variable from ``FLAGS.checkpoint_path``,
      * reports the test accuracy to stdout and ``<train_dir>/eval_details.txt``.
    """
    tic = time.time()
    print('tensorflow version:', tf.__version__)
    tf.logging.set_verbosity(tf.logging.INFO)
    if not FLAGS.dataset_dir:
        raise ValueError(
            'You must supply the dataset directory with --dataset_dir')

    # init
    net_name_scope_pruned = FLAGS.net_name_scope_pruned
    net_name_scope_checkpoint = FLAGS.net_name_scope_checkpoint
    indexed_prune_scopes_for_units = valid_indexed_prune_scopes_for_units
    kept_percentage = FLAGS.kept_percentage

    # One kept percentage per unit; only unit FLAGS.block_id is pruned.
    config = [1.0] * len(indexed_prune_scopes_for_units)
    config[FLAGS.block_id] = kept_percentage
    print("config:", config)

    # prepare the prune info for the selected unit
    indexed_prune_scopes = indexed_prune_scopes_for_units[FLAGS.block_id]
    prune_info = indexed_prune_scopes_to_prune_info(indexed_prune_scopes,
                                                    kept_percentage)
    print("prune_info:", prune_info)

    # prepare file system
    results_dir = os.path.join(FLAGS.train_dir, 'id' + str(FLAGS.block_id))
    train_dir = os.path.join(results_dir, 'kp' + str(kept_percentage))

    prune_scopes = indexed_prune_scopes_to_prune_scopes(
        indexed_prune_scopes, net_name_scope_checkpoint)
    shorten_scopes = indexed_prune_scopes_to_shorten_scopes(
        indexed_prune_scopes, net_name_scope_checkpoint)
    variables_init_value = get_init_values_for_pruned_layers(
        prune_scopes, shorten_scopes, kept_percentage)
    # The same scopes, renamed from the checkpoint name scope to the pruned
    # network's name scope.
    reinit_scopes = [
        re.sub(net_name_scope_checkpoint, net_name_scope_pruned, v)
        for v in prune_scopes + shorten_scopes
    ]

    prepare_file_system(train_dir)

    def write_detailed_info(info):
        with open(os.path.join(train_dir, 'eval_details.txt'), 'a') as f:
            f.write(info + '\n')

    info = 'train_dir:' + train_dir + '\n'
    info += 'block_id:' + str(FLAGS.block_id) + '\n'
    info += 'configuration: ' + str(config) + '\n'
    info += 'indexed_prune_scopes: ' + str(indexed_prune_scopes) + '\n'
    info += 'kept_percentage: ' + str(kept_percentage)
    print(info)
    write_detailed_info(info)

    with tf.Graph().as_default():
        #######################
        # Config model_deploy #
        #######################
        deploy_config = model_deploy.DeploymentConfig(
            num_clones=FLAGS.num_clones,
            clone_on_cpu=FLAGS.clone_on_cpu,
            replica_id=FLAGS.task,
            num_replicas=FLAGS.worker_replicas,
            num_ps_tasks=FLAGS.num_ps_tasks)

        ######################
        # Select the dataset #
        ######################
        test_dataset = dataset_factory.get_dataset(FLAGS.dataset_name,
                                                   FLAGS.test_dataset_name,
                                                   FLAGS.dataset_dir)
        test_images, test_labels = test_inputs(test_dataset, deploy_config,
                                               FLAGS)

        ######################
        # Select the network #
        ######################
        network_fn_pruned = nets_factory.get_network_fn_pruned(
            FLAGS.model_name,
            prune_info=prune_info,
            num_classes=(test_dataset.num_classes - FLAGS.labels_offset),
            weight_decay=FLAGS.weight_decay)
        print('HG: prune_info:')
        pprint(prune_info)

        ####################
        # Define the model #
        ####################
        logits_eval, _ = network_fn_pruned(test_images,
                                           is_training=False,
                                           is_local_train=False,
                                           reuse_variables=False,
                                           scope=net_name_scope_pruned)
        correct_prediction = add_correct_prediction(logits_eval, test_labels)

        print("HG: trainable_variables=", len(tf.trainable_variables()))
        print("HG: model_variables=", len(tf.model_variables()))
        print("HG: global_variables=", len(tf.global_variables()))

        sess_config = tf.ConfigProto(intra_op_parallelism_threads=16,
                                     inter_op_parallelism_threads=16)

        with tf.Session(config=sess_config) as sess:
            #################################
            # Reinit pruned model variables #
            #################################
            variables_to_reinit = get_model_variables_within_scopes(
                reinit_scopes)
            print_list("Initialize pruned variables", variables_to_reinit)
            assign_ops = []
            for v in variables_to_reinit:
                key = re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                             v.op.name)
                if key not in variables_init_value:
                    raise ValueError(
                        "Key not in variables_init_value, key=%s" % key)
                assign_ops.append(
                    tf.assign(v,
                              tf.convert_to_tensor(variables_init_value[key]),
                              validate_shape=True))
            sess.run(tf.group(*assign_ops))

            #####################################
            # Restore unchanged model variables #
            #####################################
            # Map pruned-scope names back to checkpoint names.
            variables_to_restore = {
                re.sub(net_name_scope_pruned, net_name_scope_checkpoint,
                       v.op.name): v
                for v in get_model_variables_within_scopes()
                if v not in variables_to_reinit
            }
            print_list("restore model variables", variables_to_restore.values())
            load_checkpoint(sess,
                            FLAGS.checkpoint_path,
                            var_list=variables_to_restore)

            ###########################
            # Kick off the evaluation #
            ###########################
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            print('HG: # of threads=', len(threads))

            try:
                eval_start = time.time()
                test_accuracy, _ = evaluate_accuracy(sess,
                                                     coord,
                                                     test_dataset.num_samples,
                                                     test_images,
                                                     test_labels,
                                                     test_images,
                                                     test_labels,
                                                     correct_prediction,
                                                     FLAGS.test_batch_size,
                                                     run_meta=False)
                eval_time = time.time() - eval_start
                info = ('%s: test_accuracy = %.6f') % (datetime.now(),
                                                       test_accuracy)
                print(info)
                write_detailed_info(info)
            finally:
                # Always stop the queue-runner threads, even when evaluation
                # raised, so the process does not hang on exit.
                coord.request_stop()
                coord.join(threads)

    total_time = time.time() - tic
    # NOTE(review): eval_time is reported under the "training time" label. The
    # string is kept unchanged in case log parsers depend on it.
    info = "HG: training time(min): %.1f, total time(min): %.1f" % (
        eval_time / 60.0, total_time / 60.0)
    print(info)
    write_detailed_info(info)