def train():
    """Train CIFAR-10 for a number of steps, pruning the model as it trains.

    Builds the CIFAR-10 input pipeline, model, loss and training op. It then
    attaches the pruning mask-update op so that it runs right after each
    training step. Training loops inside a ``tf.train.MonitoredTrainingSession``
    (checkpoint dir ``FLAGS.train_dir``) until ``FLAGS.max_steps`` is reached.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()
        print(images.shape)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)
        print("logits shape:", logits.shape)

        # Calculate loss.
        print("label shape", labels.shape)
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        # Parse pruning hyperparameters.
        pruning_hparams = pruning.get_pruning_hparams().parse(
            FLAGS.pruning_hparams)

        # Create a pruning object using the pruning hyperparameters.
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # Use the pruning_obj to add ops to the training graph that update the
        # masks. The conditional mask update only changes the masks when the
        # training step is in [begin_pruning_step, end_pruning_step], as given
        # in the pruning spec.
        #
        # The op is built under a control dependency on train_op. A single
        # session.run() therefore performs one training step followed by the
        # mask update. Running the two ops in separate session.run() calls
        # fires every session hook twice per step. The logger and NaN hooks
        # would then fetch `loss` again, re-evaluating the forward pass and
        # input pipeline on data that is never trained on, and the logged step
        # counter would advance twice per training step.
        with tf.control_dependencies([train_op]):
            train_and_prune_op = pruning_obj.conditional_mask_update_op()

        # Use the pruning_obj to add summaries to the graph to track the
        # sparsity of each of the layers.
        pruning_obj.add_pruning_summaries()

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime every 10 session runs."""

            def begin(self):
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step % 10 == 0:
                    # NOTE(review): assumes cifar10.distorted_inputs() yields
                    # batches of 128 examples -- confirm against cifar10.py.
                    num_examples_per_step = 128
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.datetime.now(), self._step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(loss),
                       _LoggerHook()],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                # One run = one training step followed by the mask update.
                mon_sess.run(train_and_prune_op)
def train():
    """Train LeNet for a number of steps, pruning the model as it trains.

    Reads the 'train' split via ``mnist.get_split`` through a TF-Slim data
    provider and batches it with ``LeNet.BATCH_SIZE``. It then builds the
    LeNet model, loss and training op, and attaches the pruning mask-update op
    so that it runs right after each training step. Training loops inside a
    ``tf.train.MonitoredTrainingSession`` (checkpoint dir ``FLAGS.train_dir``)
    until ``FLAGS.max_steps`` is reached.
    """
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        dataset = mnist.get_split('train', './tmp/LeNet_data')

        # Creates a TF-Slim DataProvider which reads the dataset in the
        # background during both training and testing.
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset, num_readers=10, shuffle=True)
        image, label = provider.get(['image', 'label'])

        # Batch up some training data.
        images, labels = tf.train.batch([image, label],
                                        batch_size=LeNet.BATCH_SIZE)
        print(images.shape)
        images = tf.cast(images, tf.float32)
        print(images.shape)

        logits = LeNet.inference(images)
        print("logits shape:", logits.shape)

        # Calculate loss.
        print("label shape", labels.shape)
        loss = LeNet.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = LeNet.train(loss, global_step)

        # Parse pruning hyperparameters.
        pruning_hparams = pruning.get_pruning_hparams().parse(
            FLAGS.pruning_hparams)

        # Create a pruning object using the pruning hyperparameters.
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # Use the pruning_obj to add ops to the training graph that update the
        # masks. The conditional mask update only changes the masks when the
        # training step is in [begin_pruning_step, end_pruning_step], as given
        # in the pruning spec.
        #
        # The op is built under a control dependency on train_op. A single
        # session.run() therefore performs one training step followed by the
        # mask update. Running the two ops in separate session.run() calls
        # fires every session hook twice per step. The logger and NaN hooks
        # would then fetch `loss` again, dequeuing an extra batch from
        # tf.train.batch that is never trained on, and the logged step
        # counter would advance twice per training step.
        with tf.control_dependencies([train_op]):
            train_and_prune_op = pruning_obj.conditional_mask_update_op()

        # Use the pruning_obj to add summaries to the graph to track the
        # sparsity of each of the layers.
        pruning_obj.add_pruning_summaries()

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime every 10 session runs."""

            def begin(self):
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step % 10 == 0:
                    # Must match the batch size passed to tf.train.batch above.
                    num_examples_per_step = LeNet.BATCH_SIZE
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.datetime.now(), self._step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                # One run = one training step followed by the mask update.
                mon_sess.run(train_and_prune_op)
def train():
    """Fine-tune and prune VGG-16 on the ImageNet 'train' split for a number of steps.

    Images are preprocessed with the 'vgg_16' preprocessing function at
    224x224 and batched. The VGG model, loss and training op are built, and the
    pruning mask-update op is attached so that it runs right after each
    training step.

    Pretrained VGG-16 variables are loaded from
    ``trained_weights/vgg_16.ckpt``, but only when the session starts fresh,
    i.e. when there is no checkpoint to resume from in ``FLAGS.train_dir``.
    Training loops inside a ``tf.train.MonitoredTrainingSession`` until
    ``FLAGS.max_steps`` is reached.
    """
    batch_size = 32
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        dataset = imagenet.get_split(
            'train', '/data/ramyadML/TF-slim-data/imageNet/processed')

        # Creates a TF-Slim DataProvider which reads the dataset in the
        # background during both training and testing.
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=4,
            common_queue_capacity=20 * batch_size,
            common_queue_min=10 * batch_size,
            shuffle=True)

        preprocessing_name = 'vgg_16'
        image_preprocessing_fn = preprocessing_factory.get_preprocessing(
            preprocessing_name, is_training=True)

        image, label = provider.get(['image', 'label'])
        image = image_preprocessing_fn(image, 224, 224)
        # Shift labels down by one. NOTE(review): presumably the dataset
        # reserves label 0 (e.g. a background class) that the VGG output layer
        # does not have -- confirm against the imagenet dataset definition.
        label -= 1

        # Batch up some training data.
        images, labels = tf.train.batch([image, label],
                                        batch_size=batch_size,
                                        num_threads=4,
                                        capacity=5 * batch_size)
        print(images.shape)
        images = tf.cast(images, tf.float32)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = vgg.inference(images)
        print("logits shape:", logits.shape)

        # Calculate loss.
        print("label shape", labels.shape)
        loss = vgg.loss(logits, labels)

        # Variables to load from the pretrained VGG-16 checkpoint.
        # tf.get_collection filters by re.match, i.e. by name prefix. The list
        # is therefore collected before vgg.train() runs, so that no variable
        # created there can match these prefixes.
        list_var_names = [
            'vgg_16/conv1/conv1_1/biases', 'vgg_16/conv1/conv1_1/weights',
            'vgg_16/conv1/conv1_2/biases', 'vgg_16/conv1/conv1_2/weights',
            'vgg_16/conv2/conv2_1/biases', 'vgg_16/conv2/conv2_1/weights',
            'vgg_16/conv2/conv2_2/biases', 'vgg_16/conv2/conv2_2/weights',
            'vgg_16/conv3/conv3_1/biases', 'vgg_16/conv3/conv3_1/weights',
            'vgg_16/conv3/conv3_2/biases', 'vgg_16/conv3/conv3_2/weights',
            'vgg_16/conv3/conv3_3/biases', 'vgg_16/conv3/conv3_3/weights',
            'vgg_16/conv4/conv4_1/biases', 'vgg_16/conv4/conv4_1/weights',
            'vgg_16/conv4/conv4_2/biases', 'vgg_16/conv4/conv4_2/weights',
            'vgg_16/conv4/conv4_3/biases', 'vgg_16/conv4/conv4_3/weights',
            'vgg_16/conv5/conv5_1/biases', 'vgg_16/conv5/conv5_1/weights',
            'vgg_16/conv5/conv5_2/biases', 'vgg_16/conv5/conv5_2/weights',
            'vgg_16/conv5/conv5_3/biases', 'vgg_16/conv5/conv5_3/weights',
            'vgg_16/fc6/biases', 'vgg_16/fc6/weights',
            'vgg_16/fc7/biases', 'vgg_16/fc7/weights',
            'vgg_16/fc8/biases', 'vgg_16/fc8/weights']
        var_list_to_restore = [
            var
            for name in list_var_names
            for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, name)]
        saver = tf.train.Saver(var_list_to_restore)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = vgg.train(loss, global_step)

        # Parse pruning hyperparameters.
        pruning_hparams = pruning.get_pruning_hparams().parse(
            FLAGS.pruning_hparams)

        # Create a pruning object using the pruning hyperparameters.
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # Use the pruning_obj to add ops to the training graph that update the
        # masks. The conditional mask update only changes the masks when the
        # training step is in [begin_pruning_step, end_pruning_step], as given
        # in the pruning spec.
        #
        # The op is built under a control dependency on train_op. A single
        # session.run() therefore performs one training step followed by the
        # mask update. Running the two ops in separate session.run() calls
        # fires every session hook twice per step. The logger and NaN hooks
        # would then fetch `loss` again, dequeuing an extra batch from
        # tf.train.batch that is never trained on, and the logged step
        # counter would advance twice per training step.
        with tf.control_dependencies([train_op]):
            train_and_prune_op = pruning_obj.conditional_mask_update_op()

        # Use the pruning_obj to add summaries to the graph to track the
        # sparsity of each of the layers.
        pruning_obj.add_pruning_summaries()

        def _restore_pretrained(unused_scaffold, sess):
            """Load the pretrained VGG-16 weights into a freshly initialized session."""
            saver.restore(sess, "trained_weights/vgg_16.ckpt")

        # Scaffold.init_fn runs after init_op, and only when the session is not
        # being restored from a checkpoint in FLAGS.train_dir. Restoring
        # unconditionally after the session is created would overwrite resumed
        # weights with the pretrained ones (while keeping the resumed
        # global_step). It would also run the restore through the hooked
        # session.
        scaffold = tf.train.Scaffold(init_fn=_restore_pretrained)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime every 10 session runs."""

            def begin(self):
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step % 10 == 0:
                    # Must match the batch size passed to tf.train.batch above.
                    num_examples_per_step = batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.datetime.now(), self._step,
                                        loss_value, examples_per_sec,
                                        sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                scaffold=scaffold,
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(loss),
                       _LoggerHook()],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                # One run = one training step followed by the mask update.
                mon_sess.run(train_and_prune_op)