def testConditionalMaskUpdate(self):
    param_list = [
        "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
    weights = variables.Variable(
        math_ops.linspace(1.0, 100.0, 100), name="weights")
    masked_weights = pruning.apply_mask(weights)
    sparsity = variables.Variable(0.00, name="sparsity")
    # Set up pruning
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    p._spec.threshold_decay = 0.0
    mask_update_op = p.conditional_mask_update_op()
    sparsity_val = math_ops.linspace(0.0, 0.9, 10)
    increment_global_step = state_ops.assign_add(self.global_step, 1)
    non_zero_count = []
    with self.test_session() as session:
        variables.global_variables_initializer().run()
        for i in range(10):
            session.run(state_ops.assign(sparsity, sparsity_val[i]))
            session.run(mask_update_op)
            session.run(increment_global_step)
            non_zero_count.append(np.count_nonzero(masked_weights.eval()))
    # Weights pruned at steps 0, 2, 4, and 6
    expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
    self.assertAllEqual(expected_non_zero_count, non_zero_count)
def testPerLayerBlockSparsity(self):
    param_list = [
        "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
        "block_pooling_function=AVG", "threshold_decay=0.0"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    with variable_scope.variable_scope("layer1"):
        w1 = constant_op.constant([[-0.1, 0.1], [-0.2, 0.2]], name="weights")
        pruning.apply_mask(w1)

    with variable_scope.variable_scope("layer2"):
        w2 = constant_op.constant(
            [[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]], name="weights")
        pruning.apply_mask(w2)

    sparsity = variables.VariableV1(0.5, name="sparsity")
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    mask_update_op = p.mask_update_op()

    with self.cached_session() as session:
        variables.global_variables_initializer().run()
        session.run(mask_update_op)
        mask1_eval = session.run(pruning.get_masks()[0])
        mask2_eval = session.run(pruning.get_masks()[1])

        self.assertAllEqual(
            session.run(pruning.get_weight_sparsity()), [0.5, 0.5])
        self.assertAllEqual(mask1_eval, [[0.0, 0.0], [1., 1.]])
        self.assertAllEqual(mask2_eval, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
def _setup_graph(self):
    """Builds the pruning object and its conditional mask-update op."""
    default_dict = {
        'name': 'model_pruning',
        'begin_pruning_step': 0,
        'end_pruning_step': 34400,
        'target_sparsity': 0.31,
        'pruning_frequency': 344,
        'sparsity_function_begin_step': 0,
        'sparsity_function_end_step': 34400,
        'sparsity_function_exponent': 2,
    }
    for k, v in self.param_dict.items():
        if k in default_dict:
            default_dict[k] = v
    param_list = ['{}={}'.format(k, v) for k, v in default_dict.items()]
    # param_list = [
    #     "name=cifar10_pruning",
    #     "begin_pruning_step=1000",
    #     "end_pruning_step=20000",
    #     "target_sparsity=0.9",
    #     "sparsity_function_begin_step=1000",
    #     "sparsity_function_end_step=20000"
    # ]
    PRUNE_HPARAMS = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(PRUNE_HPARAMS)
    self.p = pruning.Pruning(pruning_hparams, global_step=get_global_step_var())
    self.p.add_pruning_summaries()
    self.mask_update_op = self.p.conditional_mask_update_op()
def testWeightSpecificSparsity(self):
    param_list = [
        "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100",
        "target_sparsity=0.5",
        "weight_sparsity_map=[layer2/weights:0.75]",
        "threshold_decay=0.0"
    ]
    test_spec = ",".join(param_list)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    with variable_scope.variable_scope("layer1"):
        w1 = variables.Variable(
            math_ops.linspace(1.0, 100.0, 100), name="weights")
        _ = pruning.apply_mask(w1)
    with variable_scope.variable_scope("layer2"):
        w2 = variables.Variable(
            math_ops.linspace(1.0, 100.0, 100), name="weights")
        _ = pruning.apply_mask(w2)

    p = pruning.Pruning(pruning_hparams)
    mask_update_op = p.conditional_mask_update_op()
    increment_global_step = state_ops.assign_add(self.global_step, 1)

    with self.cached_session() as session:
        variables.global_variables_initializer().run()
        for _ in range(110):
            session.run(mask_update_op)
            session.run(increment_global_step)

        self.assertAllEqual(
            session.run(pruning.get_weight_sparsity()), [0.5, 0.75])
def __init__(self, input_size, output_size, model_path: str, momentum=0.9,
             reg_str=0.0005, scope='ConvNet', pruning_start=int(10e4),
             pruning_end=int(10e5), pruning_freq=int(10), sparsity_start=0,
             sparsity_end=int(10e5), target_sparsity=0.0, dropout=0.5,
             initial_sparsity=0, wd=0.0):
    super(ConvNet, self).__init__(input_size=input_size,
                                  output_size=output_size,
                                  model_path=model_path)
    self.scope = scope
    self.momentum = momentum
    self.reg_str = reg_str
    self.dropout = dropout
    self.logger = get_logger(scope)
    self.wd = wd
    self.logger.info("creating graph...")
    with self.graph.as_default():
        self.global_step = tf.Variable(0, trainable=False)
        self._build_placeholders()
        self.logits = self._build_model()
        self.weights_matrices = pruning.get_masked_weights()
        self.sparsity = pruning.get_weight_sparsity()
        self.loss = self._loss()
        self.train_op = self._optimizer()
        self._create_metrics()
        self.saver = tf.train.Saver(var_list=tf.global_variables())
        self.hparams = pruning.get_pruning_hparams()\
            .parse('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                   ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                   'pruning_frequency={},initial_sparsity={},'
                   ' sparsity_function_exponent={}'.format(scope, pruning_start,
                                                           pruning_end,
                                                           target_sparsity,
                                                           sparsity_start,
                                                           sparsity_end,
                                                           pruning_freq,
                                                           initial_sparsity, 3))
        # Note that the global step plays an important part in the pruning
        # mechanism: the higher the global step, the closer the sparsity gets
        # to the target (end) sparsity.
        self.pruning_obj = pruning.Pruning(self.hparams,
                                           global_step=self.global_step)
        self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
        # The pruning object defines the pruning mechanism; running
        # mask_update_op is what actually prunes the model. Pruning takes
        # place at each training epoch, and its objective is to reach the
        # target (end) sparsity hyperparameter.
        self.init_variables(tf.global_variables())  # initialize variables in graph
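# The constructor above only creates `mask_update_op`; the loop that runs it is
# outside this excerpt. Below is a minimal sketch of how the op is typically
# driven. Everything in the sketch (`net`, `feed`, `NUM_EPOCHS`, the constructor
# arguments) is a placeholder for illustration, not part of the original class.
import tensorflow as tf

net = ConvNet(input_size=784, output_size=10, model_path='/tmp/convnet')
with tf.Session(graph=net.graph) as sess:
    for epoch in range(NUM_EPOCHS):
        sess.run(net.train_op, feed_dict=feed)   # regular optimizer step
        sess.run(net.mask_update_op)             # recompute masks/thresholds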
def _prune_model(self, session):
    pruning_hparams = pruning.get_pruning_hparams().parse(self.pruning_spec)
    p = pruning.Pruning(pruning_hparams, sparsity=self.sparsity)
    self.mask_update_op = p.conditional_mask_update_op()

    variables.global_variables_initializer().run()
    for _ in range(20):
        session.run(self.mask_update_op)
        session.run(self.increment_global_step)
def setUp(self):
    super(PruningHParamsTest, self).setUp()
    # Add global step variable to the graph
    self.global_step = training_util.get_or_create_global_step()
    # Add sparsity
    self.sparsity = variables.Variable(0.5, name="sparsity")
    # Parse hparams
    self.pruning_hparams = pruning.get_pruning_hparams().parse(
        self.TEST_HPARAMS)
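# `self.TEST_HPARAMS` is a class attribute defined outside this excerpt. An
# illustrative value (not the original one) in the comma-separated format that
# get_pruning_hparams().parse() expects would be:
TEST_HPARAMS = ("name=test,threshold_decay=0.9,pruning_frequency=10,"
                "sparsity_function_end_step=100,target_sparsity=0.9")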
def __init__(self, model, data_handle, hyperparams):
    self.model = model
    self.data_handle = data_handle
    self.hyperparams = hyperparams

    # get defined tensors
    self.X = self.model.X
    self.Y = self.model.Y
    self.result = self.model.result
    self.train = self.model.Utils.is_train
    self.update = self.model.Utils.tensor_updated

    self.learning_rate = tf.placeholder(tf.float32)
    self.global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
    self.weights_decay = self.hyperparams['weights_decay']
    self.global_step_update = tf.assign_add(self.global_step,
                                            tf.constant(2, dtype=tf.int32))

    # optimizer
    self.cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=self.Y,
                                                logits=self.result))
    self.l2_loss = tf.add_n(
        [tf.nn.l2_loss(var) for var in tf.trainable_variables()])
    self.loss = self.l2_loss * self.weights_decay + self.cross_entropy
    # train_step = tf.train.AdamOptimizer(learning_rate = 0.001).minimize(loss)
    self.train_step = tf.train.MomentumOptimizer(
        self.learning_rate, 0.9, use_nesterov=True).minimize(self.loss)

    self.top1 = tf.equal(tf.argmax(self.result, 1), tf.argmax(self.Y, 1))
    self.top1_acc = tf.reduce_mean(tf.cast(self.top1, "float"))
    self.top5 = tf.nn.in_top_k(predictions=self.result,
                               targets=tf.argmax(self.Y, 1), k=5)
    self.top5_acc = tf.reduce_mean(tf.cast(self.top5, "float"))

    # prune
    if self.hyperparams['enable_prune']:
        pruning_hparams = pruning.get_pruning_hparams()
        pruning_hparams.begin_pruning_step = self.hyperparams['begin_pruning_step']
        pruning_hparams.end_pruning_step = self.hyperparams['end_pruning_step']
        pruning_hparams.pruning_frequency = self.hyperparams['pruning_frequency']
        pruning_hparams.target_sparsity = self.hyperparams['target_sparsity']
        p = pruning.Pruning(pruning_hparams, global_step=self.global_step)
        self.prune_op = p.conditional_mask_update_op()

    # log
    log_prefix = ("log" +
                  "_quant_{}".format(self.hyperparams['quant_bits']) +
                  "_prune_{}".format(str(self.hyperparams["enable_prune"])) + "/")
    if not os.path.exists(log_prefix):
        os.mkdir(log_prefix)
    self.fd = open(log_prefix + self.hyperparams['model_name'], "a")
    print("model_name = {}, quant_bits = {}, enable_prune = {}, target_sparsity = {}".format(
        self.hyperparams['model_name'], self.hyperparams['quant_bits'],
        self.hyperparams['enable_prune'], self.hyperparams['target_sparsity']),
        file=self.fd)
    print(time.asctime(time.localtime(time.time())) + " train started",
          file=self.fd)

    # init_variable
    # config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    # config.gpu_options.per_process_gpu_memory_fraction = 0.6
    self.sess = tf.Session()
    self.sess.run(tf.global_variables_initializer())
def __init__(self, actor_input_dim, actor_output_dim, model_path,
             redundancy=None, last_measure=10e4, tau=0.01):
    super(StudentActor, self).__init__(model_path=model_path)
    self.actor_input_dim = (None, actor_input_dim)
    self.actor_output_dim = (None, actor_output_dim)
    self.tau = tau
    self.redundancy = redundancy
    self.last_measure = last_measure

    with self.graph.as_default():
        self.actor_global_step = tf.Variable(0, trainable=False)
        self._build_placeholders()
        self.actor_logits = self._build_actor()
        # self.gumbel_dist = self._build_gumbel(self.actor_logits)
        self.loss = self._build_loss()
        self.actor_parameters = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')
        self.actor_pruned_weight_matrices = pruning.get_masked_weights()
        self.actor_train_op = self._build_actor_train_op()
        self.actor_saver = tf.train.Saver(var_list=self.actor_parameters,
                                          max_to_keep=100)
        self.init_variables(tf.global_variables())
        self.sparsity = pruning.get_weight_sparsity()
        self.hparams = pruning.get_pruning_hparams() \
            .parse('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                   ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                   'pruning_frequency={},initial_sparsity={},'
                   ' sparsity_function_exponent={}'.format('Actor',
                                                           cfg.pruning_start,
                                                           cfg.pruning_end,
                                                           cfg.target_sparsity,
                                                           cfg.sparsity_start,
                                                           cfg.sparsity_end,
                                                           cfg.pruning_freq,
                                                           cfg.initial_sparsity, 3))
        # Note that the global step plays an important part in the pruning
        # mechanism: the higher the global step, the closer the sparsity gets
        # to the target (end) sparsity.
        self.pruning_obj = pruning.Pruning(
            self.hparams, global_step=self.actor_global_step)
        self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
        # The pruning object defines the pruning mechanism; running
        # mask_update_op is what actually prunes the model. Pruning takes
        # place at each training epoch, and its objective is to reach the
        # target (end) sparsity hyperparameter.
        self.init_variables(tf.global_variables())  # initialize variables in graph
def _blockMasking(self, hparams, weights, expected_mask):
    threshold = variables.Variable(0.0, name="threshold")
    sparsity = variables.Variable(0.5, name="sparsity")
    test_spec = ",".join(hparams)
    pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

    # Set up pruning
    p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
    with self.test_session():
        variables.global_variables_initializer().run()
        _, new_mask = p._maybe_update_block_mask(weights, threshold)
        # Check if the mask is the same size as the weights
        self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
        mask_val = new_mask.eval()
        self.assertAllEqual(mask_val, expected_mask)
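# Callers hand `_blockMasking` the block-related hparams plus the weights and
# the mask they expect back. The invocation below is illustrative only: the
# weight values and expected mask are invented for the example, and
# block_height/block_width/block_pooling_function are assumed to be valid
# hparams in the pruning-library version under test.
def testBlockMaskingAvg(self):
    param_list = ["block_height=2", "block_width=2", "threshold_decay=0",
                  "block_pooling_function=AVG"]
    weights = constant_op.constant([[0.1, 0.1, 0.2, 0.2],
                                    [0.1, 0.1, 0.2, 0.2],
                                    [0.3, 0.3, 0.4, 0.4],
                                    [0.3, 0.3, 0.4, 0.4]])
    # With sparsity fixed at 0.5 inside the helper, the two lower-magnitude
    # 2x2 blocks are masked out and the two higher-magnitude blocks are kept.
    expected_mask = [[0., 0., 0., 0.],
                     [0., 0., 0., 0.],
                     [1., 1., 1., 1.],
                     [1., 1., 1., 1.]]
    self._blockMasking(param_list, weights, expected_mask)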
def set_prune_params(s):
    # Get, Print, and Edit Pruning Hyperparameters
    pruning_hparams = pruning.get_pruning_hparams()
    print("Pruning Hyperparameters:", pruning_hparams)

    # Change hyperparameters to meet our needs
    pruning_hparams.begin_pruning_step = 0
    pruning_hparams.end_pruning_step = 250
    pruning_hparams.pruning_frequency = 1
    pruning_hparams.sparsity_function_end_step = 250
    pruning_hparams.target_sparsity = s

    # Create a pruning object using the pruning specification;
    # sparsity seems to have priority over the hparam
    p = pruning.Pruning(pruning_hparams, global_step=global_step)
    prune_op = p.conditional_mask_update_op()
    return prune_op
def pruning_params(global_step, begin_step=0, end_step=-1, pruning_freq=10,
                   sparsity_function=2000, target_sparsity=.50,
                   sparsity_exponent=1.0):
    """Creates the pruning op.

    :param global_step: the global step, needed for pruning
    :param begin_step: the global step at which to begin pruning
    :param end_step: the global step at which to end pruning
    :param pruning_freq: the frequency of global step for when to prune
    :param sparsity_function: the global step used as the end point for the
                              gradual sparsity function
    :param target_sparsity: the target sparsity
    :param sparsity_exponent: the exponent for the sparsity function
    :return: Pruning op
    """
    pruning_hparams = pruning.get_pruning_hparams()
    pruning_hparams.begin_pruning_step = begin_step
    pruning_hparams.end_pruning_step = end_step
    pruning_hparams.pruning_frequency = pruning_freq
    pruning_hparams.sparsity_function_end_step = sparsity_function
    pruning_hparams.target_sparsity = target_sparsity
    pruning_hparams.sparsity_function_exponent = sparsity_exponent

    p = pruning.Pruning(pruning_hparams, global_step=global_step,
                        sparsity=target_sparsity)
    p_op = p.conditional_mask_update_op()
    p.add_pruning_summaries()
    return p_op
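# A typical way to wire the op returned by pruning_params() into a training
# loop. Everything other than pruning_params itself (the toy model, random
# data, and step counts) is invented for illustration, and the import path
# assumes the tf.contrib model_pruning package.
import numpy as np
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1), name="weights")
logits = tf.matmul(x, pruning.apply_mask(w))  # masked weight -> prunable
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))

global_step = tf.train.get_or_create_global_step()
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)
prune_op = pruning_params(global_step, end_step=1000,
                          sparsity_function=1000, target_sparsity=0.75)

batch_x = np.random.rand(32, 784).astype(np.float32)
batch_y = np.eye(10, dtype=np.float32)[np.random.randint(0, 10, 32)]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(1000):
        sess.run(train_op, feed_dict={x: batch_x, y: batch_y})
        sess.run(prune_op)  # updates masks while the step is in the pruning window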
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu, prune_config_flag): """Creates an optimizer training op.""" global_step = tf.train.get_or_create_global_step() learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32) # Implements linear decay of the learning rate. learning_rate = tf.train.polynomial_decay( learning_rate, global_step, num_train_steps, end_learning_rate=0.0, power=1.0, cycle=False) # Implements linear warmup. I.e., if global_step < num_warmup_steps, the # learning rate will be `global_step/num_warmup_steps * init_lr`. if num_warmup_steps: global_steps_int = tf.cast(global_step, tf.int32) warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32) global_steps_float = tf.cast(global_steps_int, tf.float32) warmup_steps_float = tf.cast(warmup_steps_int, tf.float32) warmup_percent_done = global_steps_float / warmup_steps_float warmup_learning_rate = init_lr * warmup_percent_done is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32) learning_rate = ( (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate) # It is recommended that you use this optimizer for fine tuning, since this # is how the model was trained (note that the Adam m/v variables are NOT # loaded from init_checkpoint.) optimizer = AdamWeightDecayOptimizer( learning_rate=learning_rate, weight_decay_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-6, exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]) if use_tpu: optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) # memory_saving_gradients.DEBUG_LOGGING = True tvars = tf.trainable_variables() if os.getenv('DISABLE_GRAD_CHECKPOINT'): grads = tf.gradients(loss, tvars) else: grads = memory_saving_gradients.gradients(loss, tvars, checkpoints='memory') # This is how the model was pre-trained. (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0) train_op = optimizer.apply_gradients( zip(grads, tvars), global_step=global_step) # Pruning mask update ops if prune_config_flag: tf.logging.info(f'Pruning with configs {prune_config_flag}') prune_config = get_pruning_hparams().parse(prune_config_flag) prune = Pruning(prune_config, global_step=global_step) mask_update_op = prune.conditional_mask_update_op() prune.add_pruning_summaries() else: tf.logging.info('No pruning config provided, skipping pruning') mask_update_op = tf.no_op() # Normally the global step update is done inside of `apply_gradients`. # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use # a different optimizer, you should probably take this line out. new_global_step = global_step + 1 train_op = tf.group(train_op, mask_update_op, [global_step.assign(new_global_step)]) return train_op
def train():
    is_training = True

    # data pipeline
    imgs, true_boxes = gen_data_batch(re.sub(r'examples/', '', cfg.data_path),
                                      cfg.batch_size * cfg.train.num_gpus)
    imgs_split = tf.split(imgs, cfg.train.num_gpus)
    true_boxes_split = tf.split(true_boxes, cfg.train.num_gpus)

    global_step = tf.get_variable('global_step', [],
                                  initializer=tf.constant_initializer(0.),
                                  trainable=False)
    lr = tf.train.piecewise_constant(global_step, cfg.train.lr_steps,
                                     cfg.train.learning_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate=lr)

    # Calculate the gradients for each model tower.
    tower_grads = []
    summaries_buf = []
    summaries = set()
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(cfg.train.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (cfg.train.tower, i)) as scope:
                    model = PDetNet(imgs_split[i], true_boxes_split[i], is_training)
                    loss = model.compute_loss()
                    tf.get_variable_scope().reuse_variables()
                    grads_and_vars = optimizer.compute_gradients(loss)
                    # gradients_norm = summaries_gradients_norm(grads_and_vars)
                    gradients_hist = summaries_gradients_hist(grads_and_vars)
                    #summaries_buf.append(gradients_norm)
                    summaries_buf.append(gradients_hist)
                    ##sum_set = set()
                    ##sum_set.add(tf.summary.scalar("loss", loss))
                    ##summaries_buf.append(sum_set)
                    summaries_buf.append({tf.summary.scalar("loss", loss)})
                    tower_grads.append(grads_and_vars)
                    if i == 0:
                        current_loss = loss

    update_op = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    vars_det = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="PDetNet")
    grads = average_gradients(tower_grads)
    with tf.control_dependencies(update_op):
        #train_op = optimizer.minimize(loss, global_step=global_step, var_list=vars_det)
        apply_gradient_op = optimizer.apply_gradients(grads, global_step=global_step)
        train_op = tf.group(apply_gradient_op, *update_op)

    # GPU config
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    ##pruning add by lzlu
    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(cfg.prune.pruning_hparams)
    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()
    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_summaries = pruning_obj.add_pruning_summaries()
    summaries |= pruning_summaries

    for summ in summaries_buf:
        summaries |= summ
    summaries.add(tf.summary.scalar('lr', lr))
    summary_op = tf.summary.merge(list(summaries), name='summary_op')

    if cfg.summary.summary_allowed:
        summary_writer = tf.summary.FileWriter(logdir=cfg.summary.logs_path,
                                               graph=sess.graph,
                                               flush_secs=cfg.summary.summary_secs)

    # Create a saver
    saver = tf.train.Saver()
    ckpt_dir = re.sub(r'examples/', '', cfg.ckpt_path_608)

    if cfg.train.fine_tune == 0:
        # init
        sess.run(tf.global_variables_initializer())
    else:
        saver.restore(sess, cfg.train.rstd_path)

    # running
    for i in range(0, cfg.train.max_batches):
        _, loss_, gstep, sval, _ = sess.run(
            [train_op, current_loss, global_step, summary_op, mask_update_op])
        if (i % 100 == 0):
            print(i, ': ', loss_)
        if i % 1000 == 0 and i < 10000:
            saver.save(sess, ckpt_dir + str(i) + '_plate.ckpt',
                       global_step=global_step, write_meta_graph=False)
        if i % 10000 == 0:
            saver.save(sess, ckpt_dir + str(i) + '_plate.ckpt',
                       global_step=global_step, write_meta_graph=False)
        if cfg.summary.summary_allowed and gstep % cfg.summary.summ_steps == 0:
            summary_writer.add_summary(sval, global_step=gstep)
def train_with_pruning(): tf.compat.v1.reset_default_graph() # Inference network = Network(NUM_CLASSES) inputs = tf.compat.v1.placeholder(tf.float32, [None, INPUT_SIZE, INPUT_SIZE, INPUT_CHANNEL], 'inputs') logits = network.pruning_inference(inputs) # loss & accuracy labels = tf.compat.v1.placeholder(tf.int64, [None, ], 'labels') loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)) prediction = tf.argmax(tf.nn.softmax(logits), axis=1) acc = tf.reduce_mean(tf.cast(tf.equal(prediction, labels), dtype=tf.float32)) # Create pruning operator global_step = tf.train.get_or_create_global_step() pruning_hparams = pruning.get_pruning_hparams() pruning_hparams.sparsity_function_end_step = 1000 p = pruning.Pruning(pruning_hparams, global_step=global_step) mask_update_op = p.conditional_mask_update_op() p.add_pruning_summaries() # optimizer optimizer = tf.compat.v1.train.MomentumOptimizer(learning_rate=LEARNING_RATE, momentum=0.9) train_op = optimizer.minimize(loss, global_step) # loading data train_next = load_tfrecords('train') test_next = load_tfrecords('test') with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) # summaries logs_dir = './logs/with_pruning' if not os.path.exists(logs_dir): os.makedirs(logs_dir) tf.compat.v1.summary.scalar('monitor/loss', loss) tf.compat.v1.summary.scalar('monitor/acc', acc) merged_summary_op = tf.compat.v1.summary.merge_all() train_summary_writer = tf.compat.v1.summary.FileWriter(os.path.join(logs_dir, 'train'), graph=sess.graph) test_summary_writer = tf.compat.v1.summary.FileWriter(os.path.join(logs_dir, 'test'), graph=sess.graph) best_acc = 0 saver = tf.compat.v1.train.Saver() for epoch in range(NUM_EPOCHS): # training num_steps = TRAIN_SIZE // BATCH_SIZE train_acc = 0 train_loss = 0 for step in range(num_steps): x, y = sess.run(train_next) _, summary, train_acc_batch, train_loss_batch = sess.run([train_op, merged_summary_op, acc, loss], feed_dict={inputs: x, labels: y}) sess.run(mask_update_op) train_acc += train_acc_batch train_loss += train_loss_batch sys.stdout.write("\r epoch %d, step %d, training accuracy %g, training loss %g" % (epoch + 1, step + 1, train_acc_batch, train_loss_batch)) sys.stdout.flush() train_summary_writer.add_summary(summary, global_step=epoch * num_steps + step) train_summary_writer.flush() print("\n epoch %d, training accuracy %g, training loss %g" % (epoch + 1, train_acc / num_steps, train_loss / num_steps)) # testing num_steps = TEST_SIZE // BATCH_SIZE test_acc = 0 test_loss = 0 for step in range(num_steps): x, y = sess.run(test_next) summary, test_acc_batch, test_loss_batch = sess.run([merged_summary_op, acc, loss], feed_dict={inputs: x, labels: y}) test_acc += test_acc_batch test_loss += test_loss_batch test_summary_writer.add_summary(summary, global_step=(epoch * num_steps + step) * (TRAIN_SIZE // TEST_SIZE)) test_summary_writer.flush() print(" epoch %d, testing accuracy %g, testing loss %g" % (epoch + 1, test_acc / num_steps, test_loss / num_steps)) if test_acc / num_steps > best_acc: best_acc = test_acc / num_steps saver.save(sess, './ckpt_with_pruning/model') print(" Best Testing Accuracy %g" % best_acc)
def train(): with tf.Graph().as_default(): with tf.device('/gpu:' + str(GPU_INDEX)): pointclouds_pl = MODEL.placeholder_input(BATCH_SIZE, NUM_POINT) labels_pl = MODEL.placeholder_label(BATCH_SIZE) if not FLAGS.quantize_delay: is_training = tf.placeholder(tf.bool, shape=(), name="is_training") else: is_training = True # Note the global_step=batch parameter to minimize. # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains. batch = tf.Variable(0) # bn_decay = BN_INIT_DECAY bn_decay = get_bn_decay(batch) tf.summary.scalar('bn_decay', bn_decay) # Get model pred, end_points = MODEL.get_network(pointclouds_pl, is_training, bn_decay=bn_decay, dynamic=DYNAMIC, STN=STN, scale=SCALE, concat_fea=CONCAT) # Parse pruning hyperparameters pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) # Create a pruning object using the pruning specification p = pruning.Pruning(pruning_hparams, global_step=batch) # Add conditional mask update op. Executing this op will update all # the masks in the graph if the current global step is in the range # [begin_pruning_step, end_pruning_step] as specified by the pruning spec mask_update_op = p.conditional_mask_update_op() # Add summaries to keep track of the sparsity in different layers during training p.add_pruning_summaries() if FLAGS.quantize_delay and FLAGS.quantize_delay > 0: quant_scopes = ["DGCNN/get_edge_feature", "DGCNN/get_edge_feature_1", "DGCNN/get_edge_feature_2", "DGCNN/get_edge_feature_3", "DGCNN/get_edge_feature_4", "DGCNN/agg", "DGCNN/transform_net", "DGCNN/Transform", "DGCNN/dgcnn1", "DGCNN/dgcnn2", "DGCNN/dgcnn3", "DGCNN/dgcnn4", "PointNet"] tf.contrib.quantize.create_training_graph( quant_delay=FLAGS.quantize_delay) for scope in quant_scopes: my_quantization.experimental_create_training_graph(quant_delay=FLAGS.quantize_delay, scope=scope) # Get loss loss = MODEL.get_loss(pred, labels_pl, end_points) regularization_losses = tf.get_collection( tf.GraphKeys.REGULARIZATION_LOSSES) all_losses = [] all_losses.append(loss) all_losses.append(tf.add_n(regularization_losses)) total_loss = tf.add_n(all_losses) # tf.summary.scalar('loss', loss) tf.summary.scalar('loss', total_loss) correct = tf.equal(tf.argmax(pred, 1), tf.cast(labels_pl, tf.int64)) accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE) tf.summary.scalar('accuracy', accuracy) # if update_ops: # print("BN parameters: ", update_ops) # updates = tf.group(*update_ops) # train_step = control_flow_ops.with_dependencies([updates], batch) # Get training operator learning_rate = get_learning_rate(batch) tf.summary.scalar('learning_rate', learning_rate) if OPTIMIZER == 'momentum': optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM) elif OPTIMIZER == 'adam': optimizer = tf.train.AdamOptimizer(learning_rate) update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies([tf.group(*update_ops)]): train_op = optimizer.minimize(total_loss, global_step=batch) # train_op = slim.learning.create_train_op(total_loss, optimizer) # Add ops to save and restore all the variables. 
saver = tf.train.Saver(max_to_keep=51) # Create a session config = tf.ConfigProto() config.gpu_options.allow_growth = True config.allow_soft_placement = True config.log_device_placement = False sess = tf.Session(config=config) # Add summary writers merged = tf.summary.merge_all() train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), sess.graph) test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test')) # Init variables init = tf.global_variables_initializer() # To fix the bug introduced in TF 0.12.1 as in # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1 sess.run(init) # sess.run(init, {is_training_pl: True}) if FLAGS.quantize_delay and FLAGS.quantize_delay > 0: ops = {'pointclouds_pl': pointclouds_pl, 'labels_pl': labels_pl, # 'is_training_pl': is_training, 'pred': pred, 'loss': loss, 'train_op': train_op, 'merged': merged, 'step': batch, # 'mask_update_op': mask_update_op } else: ops = {'pointclouds_pl': pointclouds_pl, 'labels_pl': labels_pl, 'is_training_pl': is_training, 'pred': pred, 'loss': loss, 'train_op': train_op, 'merged': merged, 'step': batch, # 'mask_update_op': mask_update_op } ever_best = 0 if CHECKPOINT: saver.restore(sess, CHECKPOINT) for epoch in range(MAX_EPOCH): log_string(('**** EPOCH %03d ****' % (epoch)) + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + '****') sys.stdout.flush() ma = train_one_epoch(sess, ops, train_writer) if not FLAGS.quantize_delay: ma = eval_one_epoch(sess, ops, test_writer) # Save the variables to disk. if ma > ever_best: save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt")) log_string("Model saved in file: %s" % save_path) ever_best = ma log_string("Current model mean accuracy: {}".format(ma)) log_string("Best model mean accuracy: {}".format(ever_best)) else: if epoch % 5 == 0: if CHECKPOINT: save_path = saver.save(sess, os.path.join(LOG_DIR, "model-r-{}.ckpt".format(str(epoch)))) else: save_path = saver.save(sess, os.path.join(LOG_DIR, "model-{}.ckpt".format(str(epoch)))) log_string("Model saved in file: %s" % save_path)
def __init__(self, hparams=None, mode='', seed=None, init_weight=0.01, dtype=tf.float32): self.encoder_input_data = tf.placeholder(tf.int32, [None, None], name='encoder_input_data') self.decoder_output_data = tf.placeholder(tf.int32, [None, None], name='decoder_output_data') self.tgt_vocab_size = hparams.tgt_vocab_size len_temp = tf.sign(tf.add(tf.abs(tf.sign(self.encoder_input_data)), 1)) self.seq_length_encoder_input_data = tf.cast( tf.reduce_sum(len_temp, -1), tf.int32) self.batch_size = tf.size(self.seq_length_encoder_input_data) self.num_layers = hparams.num_layers self.decoder_layer_num_more = hparams.decoder_layer_num_more self.unit_type = hparams.unit_type self.num_units = hparams.num_units self.dropout = hparams.dropout self.forget_bias = hparams.forget_bias self.attention_mode = hparams.attention_mode self.time_major = hparams.time_major self.residual = hparams.residual self.train_type = hparams.train_type self.mode = mode #l2 loss relate self.use_l2_loss = hparams.use_l2_loss self.l2_rate = hparams.l2_rate self.embedding_size = hparams.embedding_size self.dtype = dtype tf.get_variable_scope().set_initializer( tf.contrib.keras.initializers.glorot_normal(seed=None)) self.global_step = tf.train.get_or_create_global_step() #pruing paramter if _hm.NEED_PRUNING: pruning_hparams = pruning.get_pruning_hparams().parse( _hm.PRUNING_PARAMS) pruning_obj = pruning.Pruning(pruning_hparams, global_step=self.global_step) #embeding variable with tf.variable_scope('embedding_var') as scope: shape = [hparams.src_vocab_size, hparams.embedding_size] self.embedding_encoder = tf.Variable(tf.random_uniform( shape, -0.01, 0.01), dtype=tf.float32, name="embedding") self.embedding_decoder = self.embedding_encoder self.crf_transmit = tf.get_variable( "crf_transmit", [self.tgt_vocab_size, self.tgt_vocab_size], initializer=tf.random_normal_initializer(0., 512**-0.5)) res = self._build_graph() if (self.mode == _hm.MODE_TRAIN): self.loss = res[1] self.update, self.learning_rate = _mb.optimizer( hparams, self.loss, self.global_step) if _hm.NEED_PRUNING: self.mask_update_op = pruning_obj.conditional_mask_update_op() pruning_obj.add_pruning_summaries() #infer here else: logits = res[0] viterbi_sequence,_ =tf.contrib.crf.crf_decode(logits,\ self.crf_transmit,\ self.seq_length_encoder_input_data) self.neroutput = tf.identity(viterbi_sequence, name="NER_output") self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=hparams.saver_max_time) self.merged_summary = tf.summary.merge_all()
def train_function(pruning_method, loss, output_dir, use_tpu):
    """Training script for resnet model.

    Args:
        pruning_method: string indicating pruning method used to compress model.
        loss: tensor float32 of the cross entropy + regularization losses.
        output_dir: string tensor indicating the directory to save summaries.
        use_tpu: boolean indicating whether to run script on a tpu.

    Returns:
        host_call: summary tensors to be computed at each training step.
        train_op: the optimization term.
    """
    global_step = tf.train.get_global_step()

    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    learning_rate = lr_schedule(current_epoch)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=FLAGS.momentum,
                                           use_nesterov=True)

    if use_tpu:
        # use CrossShardOptimizer when using TPU.
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    # UPDATE_OPS needs to be added as a dependency due to batch norm
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops), tf.name_scope('train'):
        train_op = optimizer.minimize(loss, global_step)

    if not use_tpu:
        if FLAGS.num_workers > 0:
            optimizer = tf.train.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=FLAGS.num_workers,
                total_num_replicas=FLAGS.num_workers)
            optimizer.make_session_run_hook(True)

    metrics = {
        'global_step': tf.train.get_or_create_global_step(),
        'loss': loss,
        'learning_rate': learning_rate,
        'current_epoch': current_epoch
    }

    if pruning_method == 'threshold':
        # construct the necessary hparams string from the FLAGS
        hparams_string = ('begin_pruning_step={0},'
                          'sparsity_function_begin_step={0},'
                          'end_pruning_step={1},'
                          'sparsity_function_end_step={1},'
                          'target_sparsity={2},'
                          'pruning_frequency={3},'
                          'threshold_decay=0,'
                          'use_tpu={4}'.format(
                              FLAGS.sparsity_begin_step,
                              FLAGS.sparsity_end_step,
                              FLAGS.end_sparsity,
                              FLAGS.pruning_frequency,
                              FLAGS.use_tpu,
                          ))
        # Parse pruning hyperparameters
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

        # The first layer has so few parameters that we don't need to prune it,
        # and pruning it at higher sparsity levels has very negative effects.
        if FLAGS.prune_first_layer and FLAGS.first_layer_sparsity >= 0.:
            pruning_hparams.set_hparam(
                'weight_sparsity_map',
                ['resnet_model/initial_conv:%f' % FLAGS.first_layer_sparsity])
        if FLAGS.prune_last_layer and FLAGS.last_layer_sparsity >= 0:
            pruning_hparams.set_hparam(
                'weight_sparsity_map',
                ['resnet_model/final_dense:%f' % FLAGS.last_layer_sparsity])

        # Create a pruning object using the pruning hyperparameters
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # We override the train op to also update the mask.
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()
        masks = pruning.get_masks()
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'scratch':
        masks = pruning.get_masks()
        # make sure the masks have the sparsity we expect and that it doesn't change
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'variational_dropout':
        masks = utils.add_vd_pruning_summaries(
            threshold=FLAGS.log_alpha_threshold)
        metrics.update(masks)
    elif pruning_method == 'l0_regularization':
        summaries = utils.add_l0_summaries()
        metrics.update(summaries)
    elif pruning_method == 'baseline':
        pass
    else:
        raise ValueError('Unsupported pruning method', FLAGS.pruning_method)

    host_call = (functools.partial(utils.host_call_fn, output_dir),
                 utils.format_tensors(metrics))

    return host_call, train_op
def main(_): if not FLAGS.dataset_dir: raise ValueError( 'You must supply the dataset directory with --dataset_dir') tf.logging.set_verbosity(tf.logging.INFO) with tf.Graph().as_default(): ####################### # Config model_deploy # ####################### deploy_config = model_deploy.DeploymentConfig( num_clones=FLAGS.num_clones, clone_on_cpu=FLAGS.clone_on_cpu, replica_id=FLAGS.task, num_replicas=FLAGS.worker_replicas, num_ps_tasks=FLAGS.num_ps_tasks) # Create global_step with tf.device(deploy_config.variables_device()): global_step = slim.create_global_step() ###################### # Select the dataset # ###################### dataset = dataset_factory.get_dataset(FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) ###################### # Select the network # ###################### network_fn = nets_factory.get_network_fn( FLAGS.model_name, num_classes=(dataset.num_classes - FLAGS.labels_offset), weight_decay=FLAGS.weight_decay, is_training=True) ##################################### # Select the preprocessing function # ##################################### preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name image_preprocessing_fn = preprocessing_factory.get_preprocessing( preprocessing_name, is_training=True) ############################################################## # Create a dataset provider that loads data from the dataset # ############################################################## with tf.device(deploy_config.inputs_device()): provider = slim.dataset_data_provider.DatasetDataProvider( dataset, num_readers=FLAGS.num_readers, common_queue_capacity=20 * FLAGS.batch_size, common_queue_min=10 * FLAGS.batch_size) [image, label] = provider.get(['image', 'label']) label -= FLAGS.labels_offset train_image_size = FLAGS.train_image_size or network_fn.default_image_size image = image_preprocessing_fn(image, train_image_size, train_image_size) images, labels = tf.train.batch( [image, label], batch_size=FLAGS.batch_size, num_threads=FLAGS.num_preprocessing_threads, capacity=5 * FLAGS.batch_size) labels = slim.one_hot_encoding( labels, dataset.num_classes - FLAGS.labels_offset) batch_queue = slim.prefetch_queue.prefetch_queue( [images, labels], capacity=2 * deploy_config.num_clones) #################### # Define the model # #################### def clone_fn(batch_queue): """Allows data parallelism by creating multiple clones of network_fn.""" images, labels = batch_queue.dequeue() logits, end_points = network_fn(images) ############################# # Specify the loss function # ############################# if 'AuxLogits' in end_points: slim.losses.softmax_cross_entropy( end_points['AuxLogits'], labels, label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss') slim.losses.softmax_cross_entropy( logits, labels, label_smoothing=FLAGS.label_smoothing, weights=1.0) return end_points # Gather initial summaries. summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue]) first_clone_scope = deploy_config.clone_scope(0) # Gather update_ops from the first clone. These contain, for example, # the updates for the batch_norm variables created by network_fn. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope) # Add summaries for end_points. 
end_points = clones[0].outputs for end_point in end_points: x = end_points[end_point] summaries.add(tf.summary.histogram('activations/' + end_point, x)) summaries.add( tf.summary.scalar('sparsity/' + end_point, tf.nn.zero_fraction(x))) # Add summaries for losses. for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope): summaries.add(tf.summary.scalar('losses/%s' % loss.op.name, loss)) # Add summaries for variables. for variable in slim.get_model_variables(): summaries.add(tf.summary.histogram(variable.op.name, variable)) ################################# # Configure the moving averages # ################################# if FLAGS.moving_average_decay: moving_average_variables = slim.get_model_variables() variable_averages = tf.train.ExponentialMovingAverage( FLAGS.moving_average_decay, global_step) else: moving_average_variables, variable_averages = None, None if FLAGS.quantize_delay >= 0: tf.contrib.quantize.create_training_graph( quant_delay=FLAGS.quantize_delay) ######################################### # Configure the optimization procedure. # ######################################### with tf.device(deploy_config.optimizer_device()): learning_rate = _configure_learning_rate(dataset.num_samples, global_step) optimizer = _configure_optimizer(learning_rate) summaries.add(tf.summary.scalar('learning_rate', learning_rate)) if FLAGS.sync_replicas: # If sync_replicas is enabled, the averaging will be done in the chief # queue runner. optimizer = tf.train.SyncReplicasOptimizer( opt=optimizer, replicas_to_aggregate=FLAGS.replicas_to_aggregate, total_num_replicas=FLAGS.worker_replicas, variable_averages=variable_averages, variables_to_average=moving_average_variables) elif FLAGS.moving_average_decay: # Update ops executed locally by trainer. update_ops.append( variable_averages.apply(moving_average_variables)) # Variables to train. variables_to_train = _get_variables_to_train() # and returns a train_tensor and summary_op total_loss, clones_gradients = model_deploy.optimize_clones( clones, optimizer, var_list=variables_to_train) # Add total_loss to summary. summaries.add(tf.summary.scalar('total_loss', total_loss)) # Create gradient updates. grad_updates = optimizer.apply_gradients(clones_gradients, global_step=global_step) update_ops.append(grad_updates) update_op = tf.group(*update_ops) with tf.control_dependencies([update_op]): train_tensor = tf.identity(total_loss, name='train_op') # Add the summaries from the first clone. These contain the summaries # created by model_fn and either optimize_clones() or _gather_clone_loss(). summaries |= set( tf.get_collection(tf.GraphKeys.SUMMARIES, first_clone_scope)) # Merge all summaries together. summary_op = tf.summary.merge(list(summaries), name='summary_op') ########################### # Kicks off the training. 
# ########################### if FLAGS.pruning: global mask_update_op if FLAGS.pruning_hparams is not None: pruning_hparams = pruning.get_pruning_hparams().parse( FLAGS.pruning_hparams) pruning_obj = pruning.Pruning(pruning_hparams) else: pruning_obj = pruning.Pruning() pruning_obj.print_hparams() mask_update_op = pruning_obj.conditional_mask_update_op() slim.learning.train( train_tensor, logdir=FLAGS.train_dir, train_step_fn=train_step_with_pruning_fn, master=FLAGS.master, is_chief=(FLAGS.task == 0), init_fn=_get_init_fn(), summary_op=summary_op, number_of_steps=FLAGS.max_number_of_steps, log_every_n_steps=FLAGS.log_every_n_steps, save_summaries_secs=FLAGS.save_summaries_secs, save_interval_secs=FLAGS.save_interval_secs, sync_optimizer=optimizer if FLAGS.sync_replicas else None) else: slim.learning.train( train_tensor, logdir=FLAGS.train_dir, master=FLAGS.master, is_chief=(FLAGS.task == 0), init_fn=_get_init_fn(), summary_op=summary_op, number_of_steps=FLAGS.max_number_of_steps, log_every_n_steps=FLAGS.log_every_n_steps, save_summaries_secs=FLAGS.save_summaries_secs, save_interval_secs=FLAGS.save_interval_secs, sync_optimizer=optimizer if FLAGS.sync_replicas else None)
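# The slim-based example above passes `train_step_fn=train_step_with_pruning_fn`,
# which is defined outside the excerpt. A plausible sketch of that function,
# assuming slim.learning.train_step's (sess, train_op, global_step,
# train_step_kwargs) signature and the module-level `mask_update_op` set above:
def train_step_with_pruning_fn(sess, train_op, global_step, train_step_kwargs):
    # Run the regular slim training step, then update the pruning masks.
    total_loss, should_stop = slim.learning.train_step(
        sess, train_op, global_step, train_step_kwargs)
    sess.run(mask_update_op)
    return total_loss, should_stop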
def model_bulid(self, height, width, channel, classes): x = tf.placeholder(dtype=tf.float32, shape=[None, height, width, channel]) y = tf.placeholder(dtype=tf.float32, shape=[None, classes]) # conv 1 ,if image Nx465x128x1 ,(conv 5x5 32 ,pool/2) conv1_1 = tf.nn.relu( self.conv_layer(x, ksize=[5, 5, channel, 32], stride=[1, 1, 1, 1], padding="SAME", name="conv1_1")) # Nx465x128x1 ==> Nx465x128x32 pool1_1 = self.pool_layer(conv1_1, ksize=[1, 2, 2, 1], stride=[1, 2, 2, 1], name="pool1_1") # N*232x64x32 # conv 2,(conv 5x5 32)=>(conv 5x5 64, pool/2) conv2_1 = tf.nn.relu( self.conv_layer(pool1_1, ksize=[5, 5, 32, 64], stride=[1, 1, 1, 1], padding="SAME", name="conv2_1")) pool2_1 = self.pool_layer(conv2_1, ksize=[1, 2, 2, 1], stride=[1, 2, 2, 1], name="pool2_1") # Nx116x32x128 # Flatten ft = self.flatten(pool2_1) # Dense layer,(fc 100)=>=>(fc classes) and prune optimize fc_layer1 = layers.masked_fully_connected(ft, 200) fc_layer2 = layers.masked_fully_connected(fc_layer1, 100) prediction = layers.masked_fully_connected(fc_layer2, 10) loss = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits_v2(logits=prediction, labels=y)) # original Dense layer # fc1 = self.fc_layer(ft,fc_dims=100,name="fc1") # finaloutput = self.finlaout_layer(fc1,fc_dims=10,name="final") # pruning op global_step = tf.train.get_or_create_global_step() reset_global_step_op = tf.assign(global_step, 0) # Get, Print, and Edit Pruning Hyperparameters pruning_hparams = pruning.get_pruning_hparams() print("Pruning Hyper parameters:", pruning_hparams) # Change hyperparameters to meet our needs pruning_hparams.begin_pruning_step = 0 pruning_hparams.end_pruning_step = 250 pruning_hparams.pruning_frequency = 1 pruning_hparams.sparsity_function_end_step = 250 pruning_hparams.target_sparsity = .9 # Create a pruning object using the pruning specification, sparsity seems to have priority over the hparam p = pruning.Pruning(pruning_hparams, global_step=global_step) prune_op = p.conditional_mask_update_op() # optimize LEARNING_RATE_BASE = 0.001 LEARNING_RATE_DECAY = 0.9 LEARNING_RATE_STEP = 300 gloabl_steps = tf.Variable(0, trainable=False) learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, gloabl_steps, LEARNING_RATE_STEP, LEARNING_RATE_DECAY, staircase=True) with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): optimize = tf.train.AdamOptimizer( learning_rate=learning_rate).minimize(loss, global_step) # prediction prediction_label = prediction correct_prediction = tf.equal(tf.argmax(prediction_label, 1), tf.argmax(y, 1)) accurary = tf.reduce_mean(tf.cast(correct_prediction, dtype=tf.float32)) correct_times_in_batch = tf.reduce_mean( tf.cast(correct_prediction, dtype=tf.int32)) return dict(x=x, y=y, optimize=optimize, correct_prediction=prediction_label, correct_times_in_batch=correct_times_in_batch, cost=loss, accurary=accurary, prune_op=prune_op)
def build_graph(self, hparams, scope=None): """Subclass must implement this method. Creates a sequence-to-sequence model with dynamic RNN decoder API. Args: hparams: Hyperparameter configurations. scope: VariableScope for the created subgraph; default "dynamic_seq2seq". Returns: A tuple of the form (logits, loss_tuple, final_context_state, sample_id), where: logits: float32 Tensor [batch_size x num_decoder_symbols]. loss: loss = the total loss / batch_size. final_context_state: the final state of decoder RNN. sample_id: sampling indices. Raises: ValueError: if encoder_type differs from mono and bi, or attention_option is not (luong | scaled_luong | bahdanau | normed_bahdanau). """ utils.print_out("# Creating %s graph ..." % self.mode) # Projection if not self.extract_encoder_layers: with tf.variable_scope(scope or "build_network"): with tf.variable_scope("decoder/output_projection"): if hparams.projection_type == 'sparse': self.output_layer = core_layers.MaskedFullyConnected( hparams.tgt_vocab_size, use_bias=False, name="output_projection") elif hparams.projection_type == 'dense': self.output_layer = tf.layers.Dense( hparams.tgt_vocab_size, use_bias=False, name="output_projection") else: raise ValueError("Unknown projection type %s!" % hparams.projection_type) with tf.variable_scope(scope or "dynamic_seq2seq", dtype=self.dtype): # Encoder if hparams.language_model: # no encoder for language modeling utils.print_out(" language modeling: no encoder") self.encoder_outputs = None encoder_state = None else: self.encoder_outputs, encoder_state = self._build_encoder( hparams) # Skip decoder if extracting only encoder layers if self.extract_encoder_layers: return # Decoder logits, decoder_cell_outputs, sample_id, final_context_state = ( self._build_decoder(self.encoder_outputs, encoder_state, hparams)) # Loss if self.mode != tf.contrib.learn.ModeKeys.INFER: with tf.device( model_helper.get_device_str( self.num_encoder_layers - 1, self.num_gpus)): loss = self._compute_loss(logits, decoder_cell_outputs) else: loss = tf.constant(0.0) # model pruning if hparams.pruning_hparams is not None: pruning_hparams = pruning.get_pruning_hparams().parse( hparams.pruning_hparams) self.p = pruning.Pruning(pruning_hparams, global_step=self.global_step) self.mask_update_op = self.p.conditional_mask_update_op() masks = get_masks() thresholds = get_thresholds() masks_s = [] for index, mask in enumerate(masks): masks_s.append( tf.summary.scalar(mask.name + '/sparsity', tf.nn.zero_fraction(mask))) masks_s.append( tf.summary.scalar( thresholds[index].op.name + '/threshold', thresholds[index])) masks_s.append( tf.summary.histogram(mask.name + '/mask_tensor', mask)) self.pruning_summary = tf.summary.merge([ tf.summary.scalar('sparsity', self.p._sparsity), tf.summary.scalar('last_mask_update_step', self.p._last_update_step) ] + masks_s) else: self.mask_update_op = tf.no_op() self.pruning_summary = tf.no_op() return logits, loss, final_context_state, sample_id
weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)

image_set = input_data.read_data_sets(
    '~/tensor/AgeGenderDeepLearning-master/Folds/test-folds/gender_test_fold_is_3_DefaultRun',
    one_hot=True)

#image = tf.placeholder(tf.float32, [None, 784])
#label = tf.placeholder(tf.float32, [None, 10])

layer1 = layers.masked_fully_connected(images, 512)
layer2 = layers.masked_fully_connected(layer1, 512)
logits = tf.nn.dropout(layer2, pkeep, name='drop1')

batches = int(len(image_set.train.images) / batch_size)
#loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=nlabels, logits=logits))

with tf.variable_scope("Prune_Layer", "Prune_Layer", [images]) as scope:
    pruning_hparams = pruning.get_pruning_hparams()
    print("Pruning Hyperparameters:", pruning_hparams)

    # Change hyperparameters to meet our needs
    pruning_hparams.begin_pruning_step = 0
    pruning_hparams.end_pruning_step = 250
    pruning_hparams.pruning_frequency = 1
    pruning_hparams.sparsity_function_end_step = 250
    pruning_hparams.target_sparsity = .9

    global_step = tf.train.get_or_create_global_step()
    #train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)
    reset_global_step_op = tf.assign(global_step, 0)

    p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
    prune_op = p.conditional_mask_update_op()
def train_alexnet( dataset_name='imagenet', prune=False, prune_params='', learning_rate=conf.learning_rate, num_epochs=conf.num_epochs, batch_size=conf.batch_size, learning_rate_decay_factor=conf.learning_rate_decay_factor, num_epochs_per_decay=conf.num_epochs_per_decay, dropout_rate=conf.dropout_rate, log_step=conf.log_step, checkpoint_step=conf.checkpoint_step, summary_path=conf.root_path + 'alexnet' + conf.summary_path, checkpoint_path=conf.root_path + 'alexnet' + conf.checkpoint_path, highest_accuracy_path=conf.root_path + 'alexnet' + conf.highest_accuracy_path, default_image_size=227, #224 in the paper ): """prune_params: Comma separated list of pruning-related hyperparameters ex:'begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000' """ if dataset_name is 'imagenet': num_class = conf.imagenet['num_class'] train_set_size = conf.imagenet['train_set_size'] validation_set_size = conf.imagenet['validation_set_size'] label_offset = conf.imagenet['label_offset'] label_path = conf.imagenet['label_path'] dataset_path = conf.imagenet['dataset_path'] x = tf.placeholder( tf.float32, [batch_size, default_image_size, default_image_size, 3]) y = tf.placeholder(tf.float32, [batch_size, num_class - label_offset]) keep_prob = tf.placeholder(tf.float32) #placeholder for dropout rate # prepare to train the model model = AlexNet.AlexNet(x, keep_prob, num_class - label_offset, [], prune=prune) # Link variable to model output score = model.fc8 # List of trainable variables of the layers we want to train var_list = [v for v in tf.trainable_variables()] # Op for calculating the loss with tf.name_scope("cross_ent"): loss = tf.reduce_mean( tf.nn.softmax_cross_entropy_with_logits(logits=score, labels=y)) global_step = tf.Variable(0, False) with tf.name_scope("train"): # Get gradients of all trainable variables decay_steps = int(train_set_size / batch_size * num_epochs_per_decay) learning_rate = tf.train.exponential_decay( learning_rate, global_step, decay_steps, learning_rate_decay_factor, staircase=True) # Create optimizer and apply gradient descent to the trainable variables train_op = tf.train.GradientDescentOptimizer( learning_rate).minimize(loss, global_step) # Evaluation op: Accuracy of the model with tf.name_scope("accuracy"): correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1)) accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) if prune: # Parse pruning hyperparameters prune_params = pruning.get_pruning_hparams().parse(prune_params) # Create a pruning object using the pruning specification p = pruning.Pruning(prune_params, global_step=global_step) # Add conditional mask update op. 
Executing this op will update all # the masks in the graph if the current global step is in the range # [begin_pruning_step, end_pruning_step] as specified by the pruning spec mask_update_op = p.conditional_mask_update_op() # Add summaries to keep track of the sparsity in different layers during training p.add_pruning_summaries() # Add the variables we train to the summary for var in var_list: tf.summary.histogram(var.name, var) # Add the loss to summary tf.summary.scalar('cross_entropy', loss) # Add the accuracy to the summary tf.summary.scalar('accuracy', accuracy) # Merge all summaries together merged_summary = tf.summary.merge_all() # Initialize the FileWriter writer = tf.summary.FileWriter(summary_path) # prepare the data img_train, label_train, labels_text_train = read_tfrecord( 'train', dataset_path, default_image_size=default_image_size) img_validation, label_validation, labels_text_validation = read_tfrecord( 'validation', dataset_path, default_image_size=default_image_size) coord = tf.train.Coordinator() # Initialize an saver for store model checkpoints saver = tf.train.Saver() with tf.Session() as sess: # Initialize all variables sess.run(tf.global_variables_initializer()) # Add the model graph to TensorBoard writer.add_graph(sess.graph) # Load the pretrained weights into the non-trainable layer model.load_initial_weights(sess) #start the input pipeline queue threads = tf.train.start_queue_runners(sess, coord=coord) # load the weights from checkpoint if there exists one model_saved = tf.train.get_checkpoint_state(checkpoint_path) if model_saved and model_saved.model_checkpoint_path: saver.restore(sess, model_saved.model_checkpoint_path) print('load model from ' + model_saved.model_checkpoint_path) print("{} Start training...".format(datetime.now())) print("{} Open Tensorboard at --logdir {}".format( datetime.now(), summary_path)) # Loop over number of epochs for epoch in range(num_epochs): print("{} Epoch number: {}".format(datetime.now(), epoch + 1)) highest_accuracy = 0 #highest accuracy by far if os.path.exists(highest_accuracy_path): f = open(highest_accuracy_path, 'r') highest_accuracy = float(f.read()) f.close() print('highest accuracy from previous training is %f' % highest_accuracy) train_batches_per_epoch = int( np.floor(train_set_size / batch_size)) for step in range(train_batches_per_epoch): # train the model img, l, l_text = sess.run( [img_train, label_train, labels_text_train]) _, sc, gl_step, lr = sess.run( [train_op, score, global_step, learning_rate], feed_dict={ x: img, y: l, keep_prob: dropout_rate }) if prune: # Update the masks by running the mask_update_op sess.run(mask_update_op) # Generate summary with the current batch of data and write to file if step % log_step == 0: s, aq = sess.run([merged_summary, accuracy], feed_dict={ x: img, y: l, keep_prob: 1. }) writer.add_summary( s, epoch * train_batches_per_epoch + step) print( "global_step:" + str(gl_step) + ';learning_rate:' + str(lr) + ';accuracy:', aq) #validate the model and write checkpoint if the accuracy is higher if step % checkpoint_step == 0 and step != 0: val_batches_per_epoch = int( np.floor(validation_set_size / batch_size)) print("{} Start validation".format(datetime.now())) test_acc = 0. test_count = 0 for _ in range(val_batches_per_epoch ): # val_batches_per_epoch #validate the model img, l, l_text = sess.run([ img_validation, label_validation, labels_text_validation ]) acc = sess.run(accuracy, feed_dict={ x: img, y: l, keep_prob: 1. 
}) test_acc += acc test_count += 1 test_acc /= test_count print("{} Validation Accuracy = {:.4f}".format( datetime.now(), test_acc)) # save the model if it is better than the previous best model if test_acc > highest_accuracy: print("{} Saving checkpoint of model...".format( datetime.now())) highest_accuracy = test_acc # save a checkpoint of the model checkpoint_name = os.path.join( checkpoint_path, 'model_epoch' + '.ckpt') saver.save(sess, checkpoint_name, global_step=global_step) f = open(highest_accuracy_path, 'w') f.write(str(highest_accuracy)) f.close() print("{} Model checkpoint saved at {}".format( datetime.now(), checkpoint_name)) coord.request_stop() coord.join(threads)
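# The AlexNet trainer above follows the standard model_pruning recipe: wrap the
# weight tensors in masks inside the model, parse an hparams string, build a
# pruning.Pruning object against the global step, and run the returned
# conditional_mask_update_op once per training step. A minimal, self-contained
# sketch of that recipe (TF 1.x with tf.contrib assumed; the layer name,
# placeholder shapes and hparam values below are illustrative, not taken from
# the AlexNet code):
import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning

x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.int64, [None])
with tf.variable_scope('fc'):
    w = tf.get_variable('weights', [784, 10])
    b = tf.get_variable('biases', [10], initializer=tf.zeros_initializer())
    # apply_mask creates the mask/threshold variables the Pruning object updates.
    logits = tf.matmul(x, pruning.apply_mask(w)) + b

loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits))
global_step = tf.train.get_or_create_global_step()
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, global_step)

hparams = pruning.get_pruning_hparams().parse(
    'begin_pruning_step=1000,end_pruning_step=10000,'
    'target_sparsity=0.9,pruning_frequency=100')
p = pruning.Pruning(hparams, global_step=global_step)
mask_update_op = p.conditional_mask_update_op()
# Inside the training loop (data feeding omitted):
#   sess.run(train_op, feed_dict=...)
#   sess.run(mask_update_op)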
def train(): """Train CIFAR-10 for a number of steps.""" with tf.Graph().as_default(): global_step = tf.contrib.framework.get_or_create_global_step() # Get images and labels for CIFAR-10. images, labels = cifar10.distorted_inputs() # Build a Graph that computes the logits predictions from the # inference model. logits = cifar10.inference(images) # Calculate loss. loss = cifar10.loss(logits, labels) # Build a Graph that trains the model with one batch of examples and # updates the model parameters. train_op = cifar10.train(loss, global_step) # Parse pruning hyperparameters pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) # Create a pruning object using the pruning hyperparameters pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step) # Use the pruning_obj to add ops to the training graph to update the masks # The conditional_mask_update_op will update the masks only when the # training step is in [begin_pruning_step, end_pruning_step] specified in # the pruning spec proto mask_update_op = pruning_obj.conditional_mask_update_op() # Use the pruning_obj to add summaries to the graph to track the sparsity # of each of the layers pruning_obj.add_pruning_summaries() class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 def before_run(self, run_context): self._step += 1 self._start_time = time.time() return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): duration = time.time() - self._start_time loss_value = run_values.results if self._step % 10 == 0: num_examples_per_step = 128 examples_per_sec = num_examples_per_step / duration sec_per_batch = float(duration) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print(format_str % (datetime.datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) with tf.train.MonitoredTrainingSession( checkpoint_dir=FLAGS.train_dir, hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), _LoggerHook()], config=tf.ConfigProto( log_device_placement=FLAGS.log_device_placement)) as mon_sess: while not mon_sess.should_stop(): mon_sess.run(train_op) # Update the masks mon_sess.run(mask_update_op)
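# The CIFAR-10 loop above issues two separate run calls per step: one for
# train_op and one for mask_update_op. Later examples in this document chain
# the two with a control dependency so a single fetch does both. A sketch of
# that variant, assuming the train_op and pruning_obj built inside train()
# above:
with tf.control_dependencies([train_op]):
    train_and_prune_op = pruning_obj.conditional_mask_update_op()
# ...and in the session loop:
#   while not mon_sess.should_stop():
#     mon_sess.run(train_and_prune_op)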
iter = 0 finalPreds = np.empty((0, NUM_CLASSES)) ### Remove old TensorBoard event files ### allEventFiles = os.listdir('./logs/') for file in allEventFiles: os.remove('./logs/' + file) ################ PRUNING ##################### PARAM_LIST = [ "name=FFN_Pruning_Test", "pruning_frequency=10", "target_sparsity=0.5" ] TEST_HPARAMS = ",".join(PARAM_LIST) # Parse pruning hyperparameters pruning_hparams = pruning.get_pruning_hparams().parse(TEST_HPARAMS) #pruning_hparams = model_pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) # Get (or create) the global step that drives the pruning schedule global_step = tf.train.get_or_create_global_step() # Create a pruning object using the pruning specification p = pruning.Pruning(pruning_hparams, global_step=global_step) # Add conditional mask update op. Executing this op will update all # the masks in the graph if the current global step is in the range # [begin_pruning_step, end_pruning_step] as specified by the pruning spec mask_update_op = p.conditional_mask_update_op() # Add summaries to keep track of the sparsity in different layers during training p.add_pruning_summaries() ### Data statistics ### tic = time.time()
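# The snippet above creates the Pruning object and its conditional mask update
# op, but the op only has an effect if the network's weights were wrapped in
# masks when the graph was built. One way to do that is with the masked layer
# wrappers from tf.contrib.model_pruning; a sketch, with the input size and
# hidden width as placeholders rather than values from the original script:
from tensorflow.contrib import model_pruning

inputs = tf.placeholder(tf.float32, [None, 784])
hidden = model_pruning.masked_fully_connected(inputs, 300)
logits = model_pruning.masked_fully_connected(hidden, NUM_CLASSES,
                                              activation_fn=None)
# Each masked layer registers a mask, so pruning.get_masks() and the
# mask_update_op above now have something to act on.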
def build_model(): """Builds graph for model to train with rewrites for quantization. """ g = tf.Graph() with g.as_default(), tf.device( tf.train.replica_device_setter(FLAGS.ps_tasks)): inputs, labels = hcl_input(is_training=True) #with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=True)): logits, _ = mobilenet_v1_prune.mobilenet_v1( inputs, is_training=True, depth_multiplier=FLAGS.depth_multiplier, num_classes=FLAGS.num_classes) tf.losses.softmax_cross_entropy(labels, logits) # Call rewriter to produce graph with fake quant ops and folded batch norms # quant_delay delays start of quantization till quant_delay steps, allowing # for better model accuracy. if FLAGS.quantize: tf.contrib.quantize.create_training_graph( quant_delay=get_quant_delay()) total_loss = tf.losses.get_total_loss(name='total_loss') # Configure the learning rate using an exponential decay. num_epochs_per_decay = 2.5 hcl_size = 4650035 #3523535 decay_steps = int(hcl_size / FLAGS.batch_size * num_epochs_per_decay) global_step = tf.train.get_or_create_global_step() learning_rate = tf.train.exponential_decay( get_learning_rate(), global_step, #t1f.train.get_or_create_global_step(), decay_steps, _LEARNING_RATE_DECAY_FACTOR, staircase=True) opt = tf.train.GradientDescentOptimizer(learning_rate) # Get, Print, and Edit Pruning Hyperparameters pruning_hparams = pruning.get_pruning_hparams() #print("Pruning Hyperparameters:", pruning_hparams) # Change hyperparameters to meet our needs pruning_hparams.begin_pruning_step = 200000 #pruning_hparams.end_pruning_step = 250 #pruning_hparams.pruning_frequency = 1 #pruning_hparams.sparsity_function_end_step = 250 pruning_hparams.target_sparsity = .5 print("Pruning Hyperparameters:", pruning_hparams) # Create a pruning object using the pruning specification, sparsity seems to have priority over the hparam p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.5) prune_op = p.conditional_mask_update_op() train_tensor = slim.learning.create_train_op(total_loss, optimizer=opt) slim.summaries.add_scalar_summary(total_loss, 'total_loss', 'losses') slim.summaries.add_scalar_summary(learning_rate, 'learning_rate', 'training') return g, [train_tensor, prune_op]
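# A note on the sparsity=.5 argument above: as the inline comment observes, an
# explicitly supplied sparsity takes precedence over the target_sparsity
# schedule in the hparams. Passing a variable instead of a constant makes that
# explicit and lets the training loop raise the pruning level itself; a sketch
# reusing pruning_hparams and global_step from build_model() (the 0.7 value is
# arbitrary):
sparsity = tf.Variable(0.5, name='sparsity', trainable=False)
p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=sparsity)
prune_op = p.conditional_mask_update_op()
# Later, in the training loop:
#   sess.run(tf.assign(sparsity, 0.7))   # prune 70% of the weights from now on
#   sess.run(prune_op)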
def train_fn(training_method, global_step, total_loss, train_dir, accuracy, top_5_accuracy): """Training function for the resnet model. Args: training_method: specifies the method used to sparsify networks. global_step: the current step of training/eval. total_loss: tensor float32 of the cross entropy + regularization losses. train_dir: string specifying the directory where summaries are saved. accuracy: tensor float32 batch classification accuracy. top_5_accuracy: tensor float32 batch classification accuracy (top_5 classes). Returns: hooks: session hooks that compute and save summaries at each training step. eval_metrics: set to None during training. train_op: the optimization term. """ # Roughly drops the learning rate at every 30k steps. boundaries = [30000, 60000, 90000] if FLAGS.training_steps_multiplier != 1.0: multiplier = FLAGS.training_steps_multiplier boundaries = [int(x * multiplier) for x in boundaries] tf.logging.info( 'Learning Rate boundaries are updated with multiplier:%.2f', multiplier) learning_rate = tf.train.piecewise_constant( global_step, boundaries, values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)], name='lr_schedule') optimizer = tf.train.MomentumOptimizer( learning_rate, momentum=FLAGS.momentum, use_nesterov=True) if training_method == 'set': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseSETOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal) elif training_method == 'static': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseStaticOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal) elif training_method == 'momentum': # We override the train op to also update the mask. optimizer = sparse_optimizers.SparseMomentumOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, grow_init=FLAGS.grow_init, drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False) elif training_method == 'rigl': # We override the train op to also update the mask. 
optimizer = sparse_optimizers.SparseRigLOptimizer( optimizer, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal, initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False) elif training_method == 'snip': optimizer = sparse_optimizers.SparseSnipOptimizer( optimizer, mask_init_method=FLAGS.mask_init_method, default_sparsity=FLAGS.end_sparsity, use_tpu=False) elif training_method in ('scratch', 'baseline', 'prune'): pass else: raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method) # Create the training op update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer.minimize(total_loss, global_step) if training_method == 'prune': # construct the necessary hparams string from the FLAGS hparams_string = ('begin_pruning_step={0},' 'sparsity_function_begin_step={0},' 'end_pruning_step={1},' 'sparsity_function_end_step={1},' 'target_sparsity={2},' 'pruning_frequency={3},' 'threshold_decay=0,' 'use_tpu={4}'.format( FLAGS.sparsity_begin_step, FLAGS.sparsity_end_step, FLAGS.end_sparsity, FLAGS.pruning_frequency, False, )) # Parse pruning hyperparameters pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string) # Create a pruning object using the pruning hyperparameters pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step) tf.logging.info('starting mask update op') # We override the train op to also update the mask. with tf.control_dependencies([train_op]): train_op = pruning_obj.conditional_mask_update_op() masks = pruning.get_masks() mask_metrics = utils.mask_summaries(masks) for name, tensor in mask_metrics.items(): tf.summary.scalar(name, tensor) tf.summary.scalar('learning_rate', learning_rate) tf.summary.scalar('accuracy', accuracy) tf.summary.scalar('total_loss', total_loss) tf.summary.scalar('top_5_accuracy', top_5_accuracy) # Logging drop_fraction if dynamic sparse training. if training_method in ('set', 'momentum', 'rigl', 'static'): tf.summary.scalar('drop_fraction', optimizer.drop_fraction) summary_op = tf.summary.merge_all() summary_hook = tf.train.SummarySaverHook( save_secs=300, output_dir=train_dir, summary_op=summary_op) hooks = [summary_hook] eval_metrics = None return hooks, eval_metrics, train_op
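# utils.mask_summaries above comes from the surrounding project rather than
# from tf.contrib.model_pruning. A minimal stand-in (an assumption about its
# behaviour, not the project's actual implementation) that reports the zero
# fraction of every mask would look like this:
def mask_summaries(masks):
    """Returns a dict of scalar tensors: fraction of zeroed entries per mask."""
    metrics = {}
    for mask in masks:
        # For a binary mask the zero fraction equals the layer's sparsity.
        metrics['pruning/%s/sparsity' % mask.op.name] = tf.nn.zero_fraction(mask)
    return metrics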
def main(unused_args): tf.set_random_seed(FLAGS.seed) tf.get_variable_scope().set_use_resource(True) np.random.seed(FLAGS.seed) # Load the MNIST data and set up an iterator. mnist_data = input_data.read_data_sets(FLAGS.mnist, one_hot=False, validation_size=0) train_images = mnist_data.train.images test_images = mnist_data.test.images if FLAGS.input_mask_path: reader = tf.train.load_checkpoint(FLAGS.input_mask_path) input_mask = reader.get_tensor('layer1/mask') indices = np.sum(input_mask, axis=1) != 0 train_images = train_images[:, indices] test_images = test_images[:, indices] dataset = tf.data.Dataset.from_tensor_slices( (train_images, mnist_data.train.labels.astype(np.int32))) num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0]) batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size) iterator = batched_dataset.make_one_shot_iterator() test_dataset = tf.data.Dataset.from_tensor_slices( (test_images, mnist_data.test.labels.astype(np.int32))) num_test_images = mnist_data.test.images.shape[0] test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images) test_iterator = test_dataset.make_one_shot_iterator() # Set up loss function. use_model_pruning = FLAGS.training_method != 'baseline' if FLAGS.network_type == 'fc': cross_entropy_train, _ = mnist_network_fc( iterator.get_next(), model_pruning=use_model_pruning) cross_entropy_test, accuracy_test = mnist_network_fc( test_iterator.get_next(), reuse=True, model_pruning=use_model_pruning) else: raise RuntimeError(FLAGS.network_type + ' is an unknown network type.') # Remove the extra copies: the current implementation adds the variables twice # to the collection. Improve this hacky workaround. # TODO: test the following with the convnet or any other network. if use_model_pruning: for k in ('masks', 'masked_weights', 'thresholds', 'kernel'): # del tf.get_collection_ref(k)[2] # del tf.get_collection_ref(k)[2] collection = tf.get_collection_ref(k) del collection[len(collection) // 2:] print(tf.get_collection_ref(k)) # Set up optimizer and update ops. global_step = tf.train.get_or_create_global_step() batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size if FLAGS.optimizer != 'adam': if not use_model_pruning: boundaries = [ int(round(s * batch_per_epoch)) for s in [60, 70, 80] ] else: boundaries = [ int(round(s * batch_per_epoch)) for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20] ] learning_rate = tf.train.piecewise_constant( global_step, boundaries, values=[ FLAGS.learning_rate / (3.**i) for i in range(len(boundaries) + 1) ]) else: learning_rate = FLAGS.learning_rate if FLAGS.optimizer == 'adam': opt = tf.train.AdamOptimizer(FLAGS.learning_rate) elif FLAGS.optimizer == 'momentum': opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum, use_nesterov=FLAGS.use_nesterov) elif FLAGS.optimizer == 'sgd': opt = tf.train.GradientDescentOptimizer(learning_rate) else: raise RuntimeError(FLAGS.optimizer + ' is an unknown optimizer type') custom_sparsities = { 'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale, 'layer3': FLAGS.end_sparsity * 0 } if FLAGS.training_method == 'set': # We override the train op to also update the mask. 
opt = sparse_optimizers.SparseSETOptimizer( opt, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal) elif FLAGS.training_method == 'static': # We override the train op to also update the mask. opt = sparse_optimizers.SparseStaticOptimizer( opt, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal) elif FLAGS.training_method == 'momentum': # We override the train op to also update the mask. opt = sparse_optimizers.SparseMomentumOptimizer( opt, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, grow_init=FLAGS.grow_init, drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False) elif FLAGS.training_method == 'rigl': # We override the train op to also update the mask. opt = sparse_optimizers.SparseRigLOptimizer( opt, begin_step=FLAGS.maskupdate_begin_step, end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init, frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction, drop_fraction_anneal=FLAGS.drop_fraction_anneal, initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False) elif FLAGS.training_method == 'snip': opt = sparse_optimizers.SparseSnipOptimizer( opt, mask_init_method=FLAGS.mask_init_method, default_sparsity=FLAGS.end_sparsity, custom_sparsity_map=custom_sparsities, use_tpu=False) elif FLAGS.training_method in ('scratch', 'baseline', 'prune'): pass else: raise ValueError('Unsupported pruning method: %s' % FLAGS.training_method) train_op = opt.minimize(cross_entropy_train, global_step=global_step) if FLAGS.training_method == 'prune': hparams_string = ( 'begin_pruning_step={0},sparsity_function_begin_step={0},' 'end_pruning_step={1},sparsity_function_end_step={1},' 'target_sparsity={2},pruning_frequency={3},' 'threshold_decay={4}'.format(FLAGS.prune_begin_step, FLAGS.prune_end_step, FLAGS.end_sparsity, FLAGS.pruning_frequency, FLAGS.threshold_decay)) pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string) pruning_hparams.set_hparam( 'weight_sparsity_map', ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()]) print(pruning_hparams) pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step) with tf.control_dependencies([train_op]): train_op = pruning_obj.conditional_mask_update_op() weight_sparsity_levels = pruning.get_weight_sparsity() global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks()) tf.summary.scalar('test_accuracy', accuracy_test) tf.summary.scalar('global_sparsity', global_sparsity) for k, v in zip(pruning.get_masks(), weight_sparsity_levels): tf.summary.scalar('sparsity/%s' % k.name, v) if FLAGS.training_method in ('prune', 'snip', 'baseline'): mask_init_op = tf.no_op() tf.logging.info('No mask is set, starting dense.') else: all_masks = pruning.get_masks() mask_init_op = sparse_utils.get_mask_init_fn(all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity, custom_sparsities) if FLAGS.save_model: saver = tf.train.Saver() init_op = tf.global_variables_initializer() hyper_params_string = '_'.join([ FLAGS.network_type, str(FLAGS.batch_size), str(FLAGS.learning_rate), str(FLAGS.momentum), FLAGS.optimizer, str(FLAGS.l2_scale), FLAGS.training_method, 
str(FLAGS.prune_begin_step), str(FLAGS.prune_end_step), str(FLAGS.end_sparsity), str(FLAGS.pruning_frequency), str(FLAGS.seed) ]) tf.io.gfile.makedirs(FLAGS.save_path) filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt') merged_summary_op = tf.summary.merge_all() # Run session. if not use_model_pruning: with tf.Session() as sess: summary_writer = tf.summary.FileWriter( FLAGS.save_path, graph=tf.get_default_graph()) print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy') sess.run([init_op]) tic = time.time() with tf.io.gfile.GFile(filename, 'w') as outputfile: for i in range(FLAGS.num_epochs * num_batches): sess.run([train_op]) if (i % num_batches) == (-1 % num_batches): epoch_time = time.time() - tic loss, accuracy, summary = sess.run([ cross_entropy_test, accuracy_test, merged_summary_op ]) # Write logs at every test iteration. summary_writer.add_summary(summary, i) log_str = '%d, %.4f, %.4f, %.4f' % ( i // num_batches, epoch_time, loss, accuracy) print(log_str) print(log_str, file=outputfile) tic = time.time() if FLAGS.save_model: saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt')) else: with tf.Session() as sess: summary_writer = tf.summary.FileWriter( FLAGS.save_path, graph=tf.get_default_graph()) log_str = ','.join([ 'Epoch', 'Iteration', 'Test loss', 'Test accuracy', 'G_Sparsity', 'Sparsity Layer 0', 'Sparsity Layer 1' ]) sess.run(init_op) sess.run(mask_init_op) tic = time.time() mask_records = {} with tf.io.gfile.GFile(filename, 'w') as outputfile: print(log_str) print(log_str, file=outputfile) for i in range(FLAGS.num_epochs * num_batches): if (FLAGS.mask_record_frequency > 0 and i % FLAGS.mask_record_frequency == 0): mask_vals = sess.run(pruning.get_masks()) # Cast into bool to save space. mask_records[i] = [ a.astype(np.bool) for a in mask_vals ] sess.run([train_op]) weight_sparsity, global_sparsity_val = sess.run( [weight_sparsity_levels, global_sparsity]) if (i % num_batches) == (-1 % num_batches): epoch_time = time.time() - tic loss, accuracy, summary = sess.run([ cross_entropy_test, accuracy_test, merged_summary_op ]) # Write logs at every test iteration. summary_writer.add_summary(summary, i) log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % ( i // num_batches, i, loss, accuracy, global_sparsity_val, weight_sparsity[0], weight_sparsity[1]) print(log_str) print(log_str, file=outputfile) mask_vals = sess.run(pruning.get_masks()) if FLAGS.network_type == 'fc': sparsities, sizes = get_compressed_fc(mask_vals) print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' % (sparsities, sizes)) print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' % (sparsities, sizes), file=outputfile) tic = time.time() if FLAGS.save_model: saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt')) if mask_records: np.save(os.path.join(FLAGS.save_path, 'mask_records'), mask_records)
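# The MNIST script above saves mask_records (a dict mapping step -> list of
# boolean mask arrays) with np.save. For offline inspection of that file, the
# global sparsity can be recomputed with plain numpy; a sketch, assuming the
# default '.npy' suffix that np.save appends and the same FLAGS.save_path:
import os
import numpy as np

records_path = os.path.join(FLAGS.save_path, 'mask_records.npy')
records = np.load(records_path, allow_pickle=True).item()
for step in sorted(records):
    masks = records[step]
    total = sum(m.size for m in masks)
    zeros = sum(m.size - np.count_nonzero(m) for m in masks)
    print('step %d: global sparsity %.4f' % (step, zeros / float(total)))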
def build_graph(reader, model, train_data_pattern, label_loss_fn=losses.CrossEntropyLoss(), batch_size=1000, base_learning_rate=0.01, learning_rate_decay_examples=1000000, learning_rate_decay=0.95, optimizer_class=tf.train.AdamOptimizer, clip_gradient_norm=1.0, regularization_penalty=1, num_readers=1, num_epochs=None): """Creates the Tensorflow graph. This will only be called once in the life of a training model, because after the graph is created the model will be restored from a meta graph file rather than being recreated. Args: reader: The data file reader. It should inherit from BaseReader. model: The core model (e.g. logistic or neural net). It should inherit from BaseModel. train_data_pattern: glob path to the training data files. label_loss_fn: What kind of loss to apply to the model. It should inherit from BaseLoss. batch_size: How many examples to process at a time. base_learning_rate: What learning rate to initialize the optimizer with. optimizer_class: Which optimization algorithm to use. clip_gradient_norm: Magnitude of the gradient to clip to. regularization_penalty: How much weight to give the regularization loss compared to the label loss. num_readers: How many threads to use for I/O operations. num_epochs: How many passes to make over the data. 'None' means an unlimited number of passes. """ global_step = tf.Variable(0, trainable=False, name="global_step") local_device_protos = device_lib.list_local_devices() gpus = [x.name for x in local_device_protos if x.device_type == 'GPU'] gpus = gpus[:FLAGS.num_gpu] #gpus = gpus[-1:] num_gpus = len(gpus) if num_gpus > 0: logging.info("Using the following GPUs to train: " + str(gpus)) num_towers = num_gpus device_string = '/gpu:%d' else: logging.info("No GPUs found. Training on CPU.") num_towers = 1 device_string = '/cpu:%d' learning_rate = tf.train.exponential_decay(base_learning_rate, global_step * batch_size * num_towers, learning_rate_decay_examples, learning_rate_decay, staircase=True) tf.summary.scalar('learning_rate', learning_rate) optimizer = optimizer_class(learning_rate) unused_video_id, model_input_raw, labels_batch, num_frames = ( get_input_data_tensors(reader, train_data_pattern.split(','), batch_size=batch_size * num_towers, num_readers=num_readers, num_epochs=num_epochs)) tf.summary.histogram("model/input_raw", model_input_raw) feature_dim = len(model_input_raw.get_shape()) - 1 model_input = tf.nn.l2_normalize(model_input_raw, feature_dim) tower_inputs = tf.split(model_input, num_towers) tower_labels = tf.split(labels_batch, num_towers) tower_num_frames = tf.split(num_frames, num_towers) tower_gradients = [] tower_predictions = [] tower_label_losses = [] tower_reg_losses = [] for i in range(num_towers): with tf.device(device_string % i): with (tf.variable_scope(("tower"), reuse=True if i > 0 else None)): with (slim.arg_scope( [slim.model_variable, slim.variable], device="/cpu:0" if num_gpus != 1 else "/gpu:0")): logging.info('building graph with ' + device_string % i) result = model.create_model(tower_inputs[i], num_frames=tower_num_frames[i], vocab_size=reader.num_classes, labels=tower_labels[i]) for variable in slim.get_model_variables(): tf.summary.histogram(variable.op.name, variable) predictions = result["predictions"] tower_predictions.append(predictions) if "loss" in result.keys(): label_loss = result["loss"] else: label_loss = label_loss_fn.calculate_loss( predictions, tower_labels[i]) if "regularization_loss" in result.keys(): reg_loss = result["regularization_loss"] else: reg_loss = tf.constant(0.0) 
reg_losses = tf.losses.get_regularization_losses() if reg_losses: reg_loss += tf.add_n(reg_losses) tower_reg_losses.append(reg_loss) # Adds update_ops (e.g., moving average updates in batch normalization) as # a dependency to the train_op. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) if "update_ops" in result.keys(): update_ops += result["update_ops"] if update_ops: with tf.control_dependencies(update_ops): barrier = tf.no_op(name="gradient_barrier") with tf.control_dependencies([barrier]): label_loss = tf.identity(label_loss) tower_label_losses.append(label_loss) # Incorporate the L2 weight penalties etc. final_loss = regularization_penalty * reg_loss + label_loss gradients = optimizer.compute_gradients( final_loss, colocate_gradients_with_ops=False) tower_gradients.append(gradients) label_loss = tf.reduce_mean(tf.stack(tower_label_losses)) tf.summary.scalar("label_loss", label_loss) if regularization_penalty != 0: reg_loss = tf.reduce_mean(tf.stack(tower_reg_losses)) tf.summary.scalar("reg_loss", reg_loss) merged_gradients = utils.combine_gradients(tower_gradients) if clip_gradient_norm > 0: with tf.name_scope('clip_grads'): merged_gradients = utils.clip_gradient_norms( merged_gradients, clip_gradient_norm) train_op = optimizer.apply_gradients(merged_gradients, global_step=global_step) pruning_hparams = pruning.get_pruning_hparams().parse( FLAGS.pruning_hparams) p = pruning.Pruning(pruning_hparams, global_step=global_step) mask_update_op = p.conditional_mask_update_op() tf.add_to_collection("global_step", global_step) tf.add_to_collection("loss", label_loss) tf.add_to_collection("predictions", tf.concat(tower_predictions, 0)) tf.add_to_collection("input_batch_raw", model_input_raw) tf.add_to_collection("input_batch", model_input) tf.add_to_collection("num_frames", num_frames) tf.add_to_collection("labels", tf.cast(labels_batch, tf.float32)) tf.add_to_collection("train_op", train_op) tf.add_to_collection('mask_update_op', mask_update_op)
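# Unlike the earlier examples, build_graph() does not chain mask_update_op to
# train_op; it only registers both in graph collections. The training loop that
# uses this graph is expected to fetch them from the collections and run them
# itself. A rough sketch of that consumer side (session setup, queue runners
# and the max_steps stopping condition are assumptions, not part of this
# function):
train_op = tf.get_collection('train_op')[0]
mask_update_op = tf.get_collection('mask_update_op')[0]
global_step = tf.get_collection('global_step')[0]
loss = tf.get_collection('loss')[0]
max_steps = 100000  # placeholder

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        step = 0
        while step < max_steps:
            _, step, loss_val = sess.run([train_op, global_step, loss])
            sess.run(mask_update_op)  # apply the pruning schedule
            if step % 100 == 0:
                print('step %d, loss %.4f' % (step, loss_val))
    finally:
        coord.request_stop()
        coord.join(threads)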