Example #1
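Each example below is a train() function lifted from a larger script; the enclosing file is assumed to supply the imports and flags. For Example #1 the preamble would look roughly like this (a sketch based on the TF 1.x model_pruning example layout; the cifar10 helper path and flag defaults are assumptions):

import datetime
import time

import tensorflow as tf

from tensorflow.contrib.model_pruning.python import pruning
# Assumed helper module exposing distorted_inputs/inference/loss/train,
# as in the model_pruning CIFAR-10 example.
from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           'Directory for checkpoints and event logs.')
tf.app.flags.DEFINE_integer('max_steps', 1000000, 'Number of batches to run.')
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            'Whether to log device placement.')
tf.app.flags.DEFINE_string('pruning_hparams', '',
                           'Comma-separated pruning hyperparameters.')
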
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Get images and labels for CIFAR-10.
    images, labels = cifar10.distorted_inputs()
    print(images.shape)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    print ("logits shape:", logits.shape)
    # Calculate loss.
    print ("label shape", labels.shape)
    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_obj.add_pruning_summaries()


    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          num_examples_per_step = 128
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
        # Update the masks
        mon_sess.run(mask_update_op)
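
The pruning schedule comes entirely from FLAGS.pruning_hparams, a comma-separated name=value string parsed near the top of train(). A hypothetical value targeting 90% sparsity (the hyperparameter names come from tf.contrib.model_pruning; the schedule numbers are illustrative, not from the original script):

# Equivalent to passing this string via --pruning_hparams:
pruning_hparams = pruning.get_pruning_hparams().parse(
    'name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,'
    'target_sparsity=0.9,pruning_frequency=100')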
Example #2
def train():
    """Train LeNet for a number of steps."""
    with tf.Graph().as_default():
        global_step = tf.contrib.framework.get_or_create_global_step()

        dataset = mnist.get_split('train', './tmp/LeNet_data')

        # Creates a TF-Slim DataProvider which reads the dataset in the background
        # during both training and testing.
        provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset, num_readers=10, shuffle=True)
        image, label = provider.get(['image', 'label'])
        # batch up some training data
        images, labels = tf.train.batch([image, label],
                                        batch_size=LeNet.BATCH_SIZE)
        print(images.shape)

        images = tf.cast(images, tf.float32)
        logits = LeNet.inference(images)
        print("logits shape:", logits.shape)
        print("labels shape:", labels.shape)

        # Calculate loss.
        loss = LeNet.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = LeNet.train(loss, global_step)

        # Parse pruning hyperparameters
        pruning_hparams = pruning.get_pruning_hparams().parse(
            FLAGS.pruning_hparams)

        # Create a pruning object using the pruning hyperparameters
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # Use the pruning_obj to add ops to the training graph to update the masks
        # The conditional_mask_update_op will update the masks only when the
        # training step is in [begin_pruning_step, end_pruning_step] specified in
        # the pruning spec proto
        mask_update_op = pruning_obj.conditional_mask_update_op()

        # Use the pruning_obj to add summaries to the graph to track the sparsity
        # of each of the layers
        pruning_obj.add_pruning_summaries()

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1

            def before_run(self, run_context):
                self._step += 1
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step % 10 == 0:
                    num_examples_per_step = LeNet.BATCH_SIZE  # matches tf.train.batch above
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str %
                          (datetime.datetime.now(), self._step, loss_value,
                           examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[
                    tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                    tf.train.NanTensorHook(loss),
                    _LoggerHook()
                ],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)
                # Update the masks
                mon_sess.run(mask_update_op)
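
For conditional_mask_update_op() to have anything to prune, the weight variables created inside LeNet.inference() must carry pruning masks. A minimal sketch of such a wrapper using pruning.apply_mask from tf.contrib.model_pruning (the layer shape and names are illustrative, not LeNet's actual code):

from tensorflow.contrib.model_pruning.python import pruning

def masked_conv(inputs, kernel_shape, scope):
    """Conv layer whose weights are element-wise masked for pruning."""
    with tf.variable_scope(scope):
        weights = tf.get_variable(
            'weights', kernel_shape,
            initializer=tf.truncated_normal_initializer(stddev=0.1))
        # apply_mask creates companion mask/threshold variables that the
        # Pruning object updates during training.
        masked_weights = pruning.apply_mask(weights)
        return tf.nn.conv2d(inputs, masked_weights,
                            strides=[1, 1, 1, 1], padding='SAME')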
Example #3
def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()

    dataset = imagenet.get_split('train', '/data/ramyadML/TF-slim-data/imageNet/processed')

    # Creates a TF-Slim DataProvider which reads the dataset in the background
    # during both training and testing.
    provider = slim.dataset_data_provider.DatasetDataProvider(dataset,
                                                              num_readers=4,
                                                              common_queue_capacity=20*32,
                                                              common_queue_min=10*32,
                                                              shuffle=True)


    preprocessing_name = 'vgg_16'
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
                            preprocessing_name,
                            is_training=True)

    [image, label] = provider.get(['image', 'label'])
    image = image_preprocessing_fn(image, 224, 224)
    label -= 1

    # batch up some training data
    images, labels = tf.train.batch([image, label], 
                                    batch_size=32,
                                    num_threads=4,
                                    capacity=5*32)

    print(images.shape)

    images = tf.cast(images, tf.float32)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = vgg.inference(images)

    print ("logits shape:", logits.shape)
    # Calculate loss.
    print ("label shape", labels.shape)
    # Calculate loss.
    loss = vgg.loss(logits, labels)


    # Names of the pretrained VGG-16 variables to restore from checkpoint.
    list_var_names = ['vgg_16/conv1/conv1_1/biases',
                      'vgg_16/conv1/conv1_1/weights',
                      'vgg_16/conv1/conv1_2/biases',
                      'vgg_16/conv1/conv1_2/weights',
                      'vgg_16/conv2/conv2_1/biases',
                      'vgg_16/conv2/conv2_1/weights',
                      'vgg_16/conv2/conv2_2/biases',
                      'vgg_16/conv2/conv2_2/weights',
                      'vgg_16/conv3/conv3_1/biases',
                      'vgg_16/conv3/conv3_1/weights',
                      'vgg_16/conv3/conv3_2/biases',
                      'vgg_16/conv3/conv3_2/weights',
                      'vgg_16/conv3/conv3_3/biases',
                      'vgg_16/conv3/conv3_3/weights',
                      'vgg_16/conv4/conv4_1/biases',
                      'vgg_16/conv4/conv4_1/weights',
                      'vgg_16/conv4/conv4_2/biases',
                      'vgg_16/conv4/conv4_2/weights',
                      'vgg_16/conv4/conv4_3/biases',
                      'vgg_16/conv4/conv4_3/weights',
                      'vgg_16/conv5/conv5_1/biases',
                      'vgg_16/conv5/conv5_1/weights',
                      'vgg_16/conv5/conv5_2/biases',
                      'vgg_16/conv5/conv5_2/weights',
                      'vgg_16/conv5/conv5_3/biases',
                      'vgg_16/conv5/conv5_3/weights',
                      'vgg_16/fc6/biases',
                      'vgg_16/fc6/weights',
                      'vgg_16/fc7/biases',
                      'vgg_16/fc7/weights',
                      'vgg_16/fc8/biases',
                      'vgg_16/fc8/weights']

    var_list_to_restore = []
    for name in list_var_names:
        var_list_to_restore += tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, name)

    saver = tf.train.Saver(var_list_to_restore)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = vgg.train(loss, global_step)

    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Use the pruning_obj to add ops to the training graph to update the masks
    # The conditional_mask_update_op will update the masks only when the
    # training step is in [begin_pruning_step, end_pruning_step] specified in
    # the pruning spec proto
    mask_update_op = pruning_obj.conditional_mask_update_op()

    # Use the pruning_obj to add summaries to the graph to track the sparsity
    # of each of the layers
    pruning_obj.add_pruning_summaries()


    class _LoggerHook(tf.train.SessionRunHook):
      """Logs loss and runtime."""

      def begin(self):
        self._step = -1

      def before_run(self, run_context):
        self._step += 1
        self._start_time = time.time()
        return tf.train.SessionRunArgs(loss)  # Asks for loss value.

      def after_run(self, run_context, run_values):
        duration = time.time() - self._start_time
        loss_value = run_values.results
        if self._step % 10 == 0:
          num_examples_per_step = 32  # matches the batch_size used in tf.train.batch above
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
          print(format_str % (datetime.datetime.now(), self._step, loss_value,
                              examples_per_sec, sec_per_batch))


    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
               tf.train.NanTensorHook(loss),
               _LoggerHook()],
        config=tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)) as mon_sess:

      # Restore the pretrained VGG-16 weights before pruning starts.
      saver.restore(mon_sess, "trained_weights/vgg_16.ckpt")
      while not mon_sess.should_stop():
        mon_sess.run(train_op)
        # Update the masks
        mon_sess.run(mask_update_op)
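
After training, the per-layer sparsity the masks actually reached can be read back from the checkpoint. A minimal sketch, assuming the same graph has been rebuilt and FLAGS.train_dir holds the checkpoints written above (pruning.get_masks() is part of tf.contrib.model_pruning):

masks = pruning.get_masks()  # one mask variable per pruned layer
sparsity_ops = [tf.nn.zero_fraction(m) for m in masks]
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(FLAGS.train_dir))
    for mask, sparsity in zip(masks, sess.run(sparsity_ops)):
        print(mask.op.name, 'sparsity = %.3f' % sparsity)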