Пример #1
0
def train():
    """Run single-device VGG training on CIFAR-10 until FLAGS.max_steps.

    Builds the input pipeline (``vggnet.inputs(False)``, pinned to the CPU),
    the inference graph, the loss and the training op, then drives them in a
    ``tf.train.MonitoredTrainingSession`` with step-limit, NaN and logging
    hooks. When training stops, the total wall-clock time in seconds is
    printed.
    """
    with tf.Graph().as_default():
        step_var = tf.train.get_or_create_global_step()

        # Pin the input pipeline to CPU:0 so none of its ops drift onto the
        # GPU, which can slow training down.
        with tf.device('/cpu:0'):
            batch_images, batch_labels = vggnet.inputs(False)

        # Forward pass; the second result is handed on to vggnet.train.
        predictions, inference_tensors = vggnet.inference(batch_images)
        loss_op = vggnet.loss(predictions, batch_labels)
        # Only the training op is used here; the second result is discarded.
        step_op, _unused = vggnet.train(loss_op, inference_tensors, step_var)

        class _LoggerHook(tf.train.SessionRunHook):
            """Print loss and throughput every FLAGS.log_frequency steps."""

            _FORMAT = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                       'sec/batch)')

            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Fetch the loss together with whatever the caller runs.
                return tf.train.SessionRunArgs(loss_op)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency != 0:
                    return
                now = time.time()
                elapsed = now - self._start_time
                self._start_time = now

                throughput = FLAGS.log_frequency * FLAGS.batch_size / elapsed
                per_batch = float(elapsed / FLAGS.log_frequency)
                print(self._FORMAT % (datetime.now(), self._step,
                                      run_values.results, throughput,
                                      per_batch))

        t0 = time.time()
        session_hooks = [
            tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
            tf.train.NanTensorHook(loss_op),
            _LoggerHook(),
        ]
        session_config = tf.ConfigProto(
            log_device_placement=FLAGS.log_device_placement)
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=session_hooks,
                config=session_config) as sess:
            while not sess.should_stop():
                sess.run(step_op)
        print(time.time() - t0)
def tower_loss(scope):
    """Calculate the total loss on a single tower running the CIFAR model.

    Args:
      scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0'

    Returns:
      Tensor of shape [] containing the total loss for a batch of data
    """
    # Get images and labels for CIFAR-10.
    images, labels = vggnet.distorted_inputs()

    # Build inference Graph. vggnet.inference returns (logits, tensor_list)
    # (as unpacked by every train() variant); only the logits are needed.
    logits, _ = vggnet.inference(images)

    # Build the portion of the Graph calculating the losses. Note that we will
    # assemble the total_loss using a custom function below.
    _ = vggnet.loss(logits, labels)

    # Assemble all of the losses for the current tower only.
    losses = tf.get_collection('losses', scope)

    # Calculate the total loss for the current tower.
    total_loss = tf.add_n(losses, name='total_loss')

    # Attach a scalar summary to every individual loss and to the total loss.
    # NOTE: moving-average summaries of these losses are disabled in this
    # version; only the raw values are recorded.
    # Strip 'tower_[0-9]/' from op names so multi-GPU runs show clean names
    # on TensorBoard.
    tower_prefix = re.compile(r'%s_[0-9]*/' % vggnet.TOWER_NAME)
    for loss_tensor in losses + [total_loss]:
        loss_name = tower_prefix.sub('', loss_tensor.op.name)
        tf.summary.scalar(loss_name + ' (raw)', loss_tensor)

    total_loss = tf.identity(total_loss)
    return total_loss
Пример #3
0
def train():
    """Train the VGG model on distorted CIFAR-10 inputs with sparsity monitoring.

    Builds the input pipeline (``vggnet.distorted_inputs()``, pinned to the
    CPU), the inference graph, the loss and the training op. It then runs
    training in a ``tf.train.MonitoredTrainingSession`` until
    ``FLAGS.max_steps``. Besides the step-limit and NaN hooks, the session
    runs:

    * a ``SummarySaverHook`` that writes all merged summaries to
      ``FLAGS.sparsity_dir`` every step,
    * ``_LoggerHook``, which prints loss and throughput every
      ``FLAGS.log_frequency`` steps,
    * ``_SparsityHook``, which lets a ``sparsity_monitor.SparsityMonitor``
      pick tensors from ``retrieve_list`` to fetch each step and then
      processes the fetched values.

    The total wall-clock training time in seconds is printed at the end.
    """
    with tf.Graph().as_default() as g:
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = vggnet.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits, tensor_list = vggnet.inference(images)

        # Calculate loss.
        loss = vggnet.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters. retrieve_list is handed to the
        # SparsityMonitor below.
        train_op, retrieve_list = vggnet.train(loss, tensor_list, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        class _SparsityHook(tf.train.SessionRunHook):
            """Drives a SparsityMonitor around every training step.

            Before each step the monitor chooses which tensors to fetch.
            After the step it receives the fetched values.
            """
            def begin(self):
                self._step = -1
                mode = sparsity_monitor.Mode.monitor
                data_format = "NHWC"
                self.monitor = sparsity_monitor.SparsityMonitor(
                    mode, data_format, FLAGS.monitor_interval,
                    FLAGS.monitor_period, retrieve_list)

            def before_run(self, run_context):
                self._step += 1
                # Ask the monitor which tensors to fetch on this step.
                selected_list = self.monitor.scheduler_before(self._step)
                return tf.train.SessionRunArgs(selected_list)

            def after_run(self, run_context, run_values):
                # NOTE(review): os.getcwd() presumably serves as the monitor's
                # output directory -- confirm against sparsity_monitor.
                self.monitor.scheduler_after(run_values.results, self._step,
                                             os.getcwd(), FLAGS.file_io)

        sparsity_summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.sparsity_dir, g)

        start = time.time()
        try:
            with tf.train.MonitoredTrainingSession(
                    checkpoint_dir=FLAGS.train_dir,
                    hooks=[
                        tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                        tf.train.NanTensorHook(loss),
                        tf.train.SummarySaverHook(save_steps=1,
                                                  summary_writer=summary_writer,
                                                  summary_op=sparsity_summary_op),
                        _LoggerHook(),
                        _SparsityHook()
                    ],
                    config=tf.ConfigProto(
                        log_device_placement=FLAGS.log_device_placement)
            ) as mon_sess:
                while not mon_sess.should_stop():
                    mon_sess.run(train_op)
            end = time.time()
        finally:
            # Release the event file explicitly, even if training raises.
            summary_writer.close()
        print(end - start)
Пример #4
0
def train():
    """Train the VGG model on CIFAR-10 while tracing per-tensor sparsity.

    Builds the input pipeline (``vggnet.inputs(False)``, pinned to the CPU),
    the inference graph, the loss and the training op. It then runs training
    in a ``tf.train.MonitoredTrainingSession`` until ``FLAGS.max_steps``.
    Alongside the step-limit and NaN hooks, the session runs a loss and
    throughput logger and ``_SparsityHook`` (see its docstring). The total
    wall-clock training time in seconds is printed at the end.
    """
    with tf.Graph().as_default() as g:
        global_step = tf.train.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        # Force input pipeline to CPU:0 to avoid operations sometimes ending up on
        # GPU and resulting in a slow down.
        with tf.device('/cpu:0'):
            images, labels = vggnet.inputs(False)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits, tensor_list = vggnet.inference(images)

        # Calculate loss.
        loss = vggnet.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters. retrieve_list holds the
        # (tensor, sparsity) pairs inspected by _SparsityHook.
        train_op, retrieve_list = vggnet.train(loss, tensor_list, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and runtime."""
            def begin(self):
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                return tf.train.SessionRunArgs(loss)  # Asks for loss value.

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    loss_value = run_values.results
                    examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        class _SparsityHook(tf.train.SessionRunHook):
            """Prints per-tensor sparsity statistics over a monitoring window.

            Every step, all (tensor, sparsity) pairs in retrieve_list are
            fetched. The first time a tensor's sparsity exceeds
            FLAGS.sparsity_threshold, a window of FLAGS.monitor_interval steps
            opens for it. On each step inside the window, the hook:

            * prints the zero-block ratio and the sparsity, and
            * appends feature_map_extraction(value, 0, 0) to
              data_dict[tensor_name].

            On the step after the window closes, if FLAGS.log_animation is
            set, an animation is saved as a GIF. After that the tensor is
            ignored.
            """

            # Runtime log formats. The difference percentage is not computed
            # at the moment; 0.0 is printed in its place.
            _SPARSITY_FMT = (
                'local_step: %d %s: sparsity = %.2f difference percentage = %.2f'
            )
            _ZERO_BLOCK_FMT = 'local_step: %d %s: zero block ratio = %.2f'

            def begin(self):
                self._bs_util = block_sparsity_util.BlockSparsityUtil(
                    FLAGS.block_size)
                # tensor name -> get_non_zero_index() result recorded when
                # the tensor's window opened (not read elsewhere at present).
                self._internal_index_keeper = collections.OrderedDict()
                # tensor name -> position inside its monitoring window.
                self._local_step = collections.OrderedDict()

            def before_run(self, run_context):
                # Flatten [(tensor, sparsity), ...] into
                # [tensor0, sparsity0, tensor1, sparsity1, ...].
                self.selected_list = []
                for tensor_tuple in retrieve_list:
                    self.selected_list.append(tensor_tuple[0])
                    self.selected_list.append(tensor_tuple[1])
                return tf.train.SessionRunArgs(self.selected_list)

            def _save_animation(self, tensor_name):
                """Animate `animate` over the window and save it as a GIF.

                The file is '<tensor_name>.gif', with '/' and ':' replaced
                by '_'.
                """
                fig, ax = plt.subplots()
                ani = animation.FuncAnimation(
                    fig,
                    animate,
                    frames=FLAGS.monitor_interval,
                    fargs=(ax, tensor_name),
                    interval=500,
                    repeat=False,
                    blit=True)
                figure_name = tensor_name.replace('/', '_').replace(':', '_')
                ani.save(figure_name + '.gif', dpi=80, writer='imagemagick')
                # Release the figure; otherwise one stays open per tensor.
                plt.close(fig)

            def after_run(self, run_context, run_values):
                results = run_values.results
                # before_run interleaves the fetches, so even positions hold
                # tensor values and odd positions their sparsities.
                self._data_list = list(results[0::2])
                self._sparsity_list = list(results[1::2])
                num_pairs = len(self.selected_list) // 2
                assert len(self._sparsity_list) == num_pairs
                assert len(self._data_list) == num_pairs

                # Feature-map coordinates handed to feature_map_extraction.
                batch_idx = 0
                channel_idx = 0
                for i, (data, sparsity) in enumerate(
                        zip(self._data_list, self._sparsity_list)):
                    tensor = self.selected_list[2 * i]
                    shape = tensor.get_shape()
                    tensor_name = tensor.name
                    step = self._local_step.get(tensor_name)

                    if step is not None:
                        if (step == FLAGS.monitor_interval
                                and FLAGS.log_animation):
                            # Window just closed: render it once.
                            self._save_animation(tensor_name)
                            self._local_step[tensor_name] += 1
                            continue
                        if step >= FLAGS.monitor_interval:
                            # Window already over; nothing left to log.
                            continue

                    if step is None and sparsity > FLAGS.sparsity_threshold:
                        # First step above the threshold: open the window.
                        self._local_step[tensor_name] = 0
                        zero_block_ratio = self._bs_util.zero_block_ratio_matrix(
                            data, shape)
                        print(self._ZERO_BLOCK_FMT %
                              (self._local_step[tensor_name], tensor_name,
                               zero_block_ratio))
                        print(self._SPARSITY_FMT %
                              (self._local_step[tensor_name], tensor_name,
                               sparsity, 0.0))
                        self._internal_index_keeper[tensor_name] = (
                            get_non_zero_index(data, shape))
                        data_dict.setdefault(tensor_name, []).append(
                            feature_map_extraction(data, batch_idx,
                                                   channel_idx))
                        self._local_step[tensor_name] += 1
                    elif step is not None and step > 0:
                        # Inside the monitoring window.
                        zero_block_ratio = self._bs_util.zero_block_ratio_matrix(
                            data, shape)
                        print(self._ZERO_BLOCK_FMT %
                              (self._local_step[tensor_name], tensor_name,
                               zero_block_ratio))
                        print(self._SPARSITY_FMT %
                              (self._local_step[tensor_name], tensor_name,
                               sparsity, 0.0))
                        data_dict[tensor_name].append(
                            feature_map_extraction(data, batch_idx,
                                                   channel_idx))
                        self._local_step[tensor_name] += 1

        # Both objects are built even though the SummarySaverHook is
        # currently disabled. The FileWriter is given the graph `g`. To also
        # record summaries, add to the hooks:
        #   tf.train.SummarySaverHook(save_steps=FLAGS.log_frequency,
        #                             summary_writer=summary_writer,
        #                             summary_op=sparsity_summary_op)
        sparsity_summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(FLAGS.sparsity_dir, g)

        start = time.time()
        try:
            with tf.train.MonitoredTrainingSession(
                    checkpoint_dir=FLAGS.train_dir,
                    hooks=[
                        tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                        tf.train.NanTensorHook(loss),
                        _LoggerHook(),
                        _SparsityHook()
                    ],
                    config=tf.ConfigProto(
                        log_device_placement=FLAGS.log_device_placement)
            ) as mon_sess:
                while not mon_sess.should_stop():
                    mon_sess.run(train_op)
            end = time.time()
        finally:
            # Release the event file explicitly, even if training raises.
            summary_writer.close()
        print(end - start)