Example #1
  def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4,
                   freq_iter=2):
    """Setups a trivial training procedure for sparse training."""
    tf.reset_default_graph()
    optim = tf.train.GradientDescentOptimizer(0.1)
    sparse_optim = sparse_optimizers.SparseStaticOptimizer(
        optim, start_iter, end_iter, freq_iter, drop_fraction=drop_frac)
    x = tf.random.uniform((1, n_inp))
    y = layers.masked_fully_connected(x, n_out, activation_fn=None)
    global_step = tf.train.get_or_create_global_step()
    weight = pruning.get_weights()[0]
    # There is one masked layer to be trained.
    mask = pruning.get_masks()[0]
    # Around half of the values of the mask are set to zero with `mask_update`.
    mask_update = tf.assign(
        mask,
        tf.constant(
            np.random.choice([0, 1], size=(n_inp, n_out), p=[1./2, 1./2]),
            dtype=tf.float32))
    loss = tf.reduce_mean(y)
    train_op = sparse_optim.minimize(loss, global_step)

    # Init
    sess = tf.Session()
    init = tf.global_variables_initializer()
    sess.run(init)
    sess.run([mask_update])

    return sess, train_op, mask, weight, global_step
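
A hypothetical test body driving the fixture above (the rigl sparse_optimizers/model_pruning imports from the example are assumed); it only checks the invariant that the optimizer preserves the number of nonzero mask entries:

  def test_sparsity_is_preserved(self):
    # Sketch only: argument values are illustrative.
    sess, train_op, mask, weight, global_step = self._setup_graph(
        n_inp=4, n_out=6, drop_frac=0.5)
    nonzero_before = sess.run(mask).sum()
    for _ in range(5):
      sess.run(train_op)
    # Connections may be rewired between start_iter and end_iter, but the
    # sparsity level itself must stay constant.
    self.assertEqual(nonzero_before, sess.run(mask).sum())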
Example #2
 def snip_op():
     all_masks = pruning.get_masks()
     assigner = sparse_utils.get_mask_init_fn(all_masks,
                                              self._mask_init_method,
                                              self._default_sparsity,
                                              self._custom_sparsity_map,
                                              mask_fn=snip_fn)
     with ops.control_dependencies([assigner]):
         assign_op = state_ops.assign(self.is_snipped,
                                      True,
                                      name='assign_true_after_snipped')
     return assign_op
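
In context, snip_op would run at most once; a hedged sketch of the guard that selects it only before the masks have been snipped (self.is_snipped is the boolean variable from the example, the rest is illustrative):

 def maybe_snip_op(self):
     # If already snipped, return a constant; otherwise snip_op initializes
     # the masks and flips is_snipped to True.
     return tf.cond(self.is_snipped, lambda: tf.constant(True), snip_op)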
Example #3
def check_global_sparsity():
    """Add a summary for the weight sparsity."""
    weight_masks = magnitude_pruning.get_masks()
    weights_per_layer = []
    nonzero_per_layer = []
    for mask in weight_masks:
        nonzero_per_layer.append(tf.reduce_sum(mask))
        weights_per_layer.append(tf.size(mask))
    # Sum the per-layer counts once all masks have been visited.
    total_nonzero = tf.add_n(nonzero_per_layer)
    total_weights = tf.add_n(weights_per_layer)
    sparsity = (1.0 - (tf.cast(total_nonzero, tf.float32) /
                       tf.cast(total_weights, tf.float32)))
    tf.summary.scalar("global_weight_sparsity", sparsity)
Example #4
 def scaffold_fn():
   """For initialization, passed to the estimator."""
   if FLAGS.initial_value_checkpoint:
     initialize_parameters_from_ckpt(FLAGS.initial_value_checkpoint)
   all_masks = pruning.get_masks()
   assigner = sparse_utils.get_mask_init_fn(
       all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity,
       CUSTOM_SPARSITY_MAP)
   def init_fn(scaffold, session):
     """A callable for restoring variable from a checkpoint."""
     del scaffold  # Unused.
     session.run(assigner)
   return tf.train.Scaffold(init_fn=init_fn)
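
A scaffold built this way is normally returned from a model_fn; a hedged sketch of the consuming side, where loss and train_op are assumed to exist in that model_fn:

spec = tf.estimator.EstimatorSpec(
    mode=tf.estimator.ModeKeys.TRAIN,
    loss=loss,
    train_op=train_op,
    # init_fn runs once the session is created, initializing the masks.
    scaffold=scaffold_fn())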
Example #5
  def testMaskedLSTMCell(self):
    expected_num_masks = 1
    expected_num_rows = 2 * self.dim
    expected_num_cols = 4 * self.dim
    with self.cached_session():
      inputs = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      c = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      h = variables.Variable(
          random_ops.random_normal([self.batch_size, self.dim]))
      state = tf_rnn_cells.LSTMStateTuple(c, h)
      lstm_cell = rnn_cells.MaskedLSTMCell(self.dim)
      lstm_cell(inputs, state)
      self.assertEqual(len(pruning.get_masks()), expected_num_masks)
      self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks)
      self.assertEqual(len(pruning.get_thresholds()), expected_num_masks)
      self.assertEqual(len(pruning.get_weights()), expected_num_masks)

      for mask in pruning.get_masks():
        self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols))
      for weight in pruning.get_weights():
        self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols))
Example #6
    def metric_fn(labels, logits, cross_loss, reg_loss):
      """Calculate eval metrics."""
      logging.info('In metric function')
      eval_metrics = {}
      predictions = tf.cast(tf.argmax(logits, axis=1), tf.int32)
      in_top_5 = tf.cast(tf.nn.in_top_k(logits, labels, 5), tf.float32)
      eval_metrics['top_5_eval_accuracy'] = tf.metrics.mean(in_top_5)
      eval_metrics['cross_loss'] = tf.metrics.mean(cross_loss)
      eval_metrics['reg_loss'] = tf.metrics.mean(reg_loss)
      eval_metrics['eval_accuracy'] = tf.metrics.accuracy(
          labels=labels, predictions=predictions)

      # If evaluating once, let's also calculate sparsities.
      if FLAGS.mode == 'eval_once':
        sparsity_summaries = utils.mask_summaries(pruning.get_masks())
        # We call mean on a scalar to create tensor, update_op pairs.
        sparsity_summaries = {k: tf.metrics.mean(v) for k, v
                              in sparsity_summaries.items()}
        eval_metrics.update(sparsity_summaries)
      return eval_metrics
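
On TPU, a metric_fn with this signature is usually handed to the estimator spec together with the tensors it consumes; a hedged sketch (the surrounding model_fn is assumed):

eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
    mode=tf.estimator.ModeKeys.EVAL,
    loss=cross_loss,
    # TPUEstimatorSpec expects a (metric_fn, tensor_list) pair.
    eval_metrics=(metric_fn, [labels, logits, cross_loss, reg_loss]))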
Example #7
    def _setup_graph(self,
                     default_sparsity,
                     mask_init_method,
                     custom_sparsity_map,
                     n_inp=3,
                     n_out=5):
        """Setups a trivial training procedure for sparse training."""
        tf.reset_default_graph()
        optim = tf.train.GradientDescentOptimizer(1e-3)
        sparse_optim = sparse_optimizers.SparseDNWOptimizer(
            optim,
            default_sparsity,
            mask_init_method,
            custom_sparsity_map=custom_sparsity_map)

        inp_values = np.arange(1, n_inp + 1)
        scale_vector_values = np.random.uniform(size=(n_out, )) - 0.5
        # The weight gradient is the outer product of the input and the output
        # gradient. Since the loss is the sample sum, the output gradient is
        # equal to the scale vector.
        expected_grads = np.outer(inp_values, scale_vector_values)

        x = tf.reshape(tf.constant(inp_values, dtype=tf.float32), (1, n_inp))
        y = layers.masked_fully_connected(x, n_out, activation_fn=None)
        scale_vector = tf.constant(scale_vector_values, dtype=tf.float32)

        y = y * scale_vector
        loss = tf.reduce_sum(y)

        global_step = tf.train.get_or_create_global_step()
        grads_and_vars = sparse_optim.compute_gradients(loss)
        train_op = sparse_optim.apply_gradients(grads_and_vars,
                                                global_step=global_step)
        # Init
        sess = tf.Session()
        init = tf.global_variables_initializer()
        sess.run(init)
        mask = pruning.get_masks()[0]
        weights = pruning.get_weights()[0]
        return (sess, train_op, (expected_grads, grads_and_vars), mask,
                weights)
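
The comment about the expected gradient can be verified by hand: with y = xW and loss = sum(y * s), the weight gradient is the outer product of the input and the scale vector. A small NumPy confirmation (illustrative; the loss is linear in W, so the finite difference is exact up to floating point):

import numpy as np

n_inp, n_out = 3, 5
x = np.arange(1, n_inp + 1, dtype=np.float64)
s = np.random.uniform(size=(n_out,)) - 0.5
w = np.random.randn(n_inp, n_out)

grad = np.outer(x, s)  # analytic gradient of sum((x @ w) * s) w.r.t. w

# Finite-difference check on a single entry.
eps = 1e-6
w_eps = w.copy()
w_eps[1, 2] += eps
numeric = (np.sum((x @ w_eps) * s) - np.sum((x @ w) * s)) / eps
print(np.isclose(numeric, grad[1, 2]))  # True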
Example #8
def main(unused_args):
    tf.set_random_seed(FLAGS.seed)
    tf.get_variable_scope().set_use_resource(True)
    np.random.seed(FLAGS.seed)

    # Load the MNIST data and set up an iterator.
    mnist_data = input_data.read_data_sets(FLAGS.mnist,
                                           one_hot=False,
                                           validation_size=0)
    train_images = mnist_data.train.images
    test_images = mnist_data.test.images
    if FLAGS.input_mask_path:
        reader = tf.train.load_checkpoint(FLAGS.input_mask_path)
        input_mask = reader.get_tensor('layer1/mask')
        indices = np.sum(input_mask, axis=1) != 0
        train_images = train_images[:, indices]
        test_images = test_images[:, indices]
    dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, mnist_data.train.labels.astype(np.int32)))
    num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size
    dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])
    batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
    iterator = batched_dataset.make_one_shot_iterator()

    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_images, mnist_data.test.labels.astype(np.int32)))
    num_test_images = mnist_data.test.images.shape[0]
    test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)
    test_iterator = test_dataset.make_one_shot_iterator()

    # Set up loss function.
    use_model_pruning = FLAGS.training_method != 'baseline'

    if FLAGS.network_type == 'fc':
        cross_entropy_train, _ = mnist_network_fc(
            iterator.get_next(), model_pruning=use_model_pruning)
        cross_entropy_test, accuracy_test = mnist_network_fc(
            test_iterator.get_next(),
            reuse=True,
            model_pruning=use_model_pruning)
    else:
        raise RuntimeError(FLAGS.network_type +
                           ' is an unknown network type.')

    # The current implementation adds the variables to each collection twice,
    # so drop the duplicated second half of every collection.
    # TODO: test the following with the convnet or any other network.
    if use_model_pruning:
        for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):
            collection = tf.get_collection_ref(k)
            del collection[len(collection) // 2:]
            print(tf.get_collection_ref(k))

    # Set up optimizer and update ops.
    global_step = tf.train.get_or_create_global_step()
    batch_per_epoch = mnist_data.train.images.shape[0] // FLAGS.batch_size

    if FLAGS.optimizer != 'adam':
        if not use_model_pruning:
            boundaries = [
                int(round(s * batch_per_epoch)) for s in [60, 70, 80]
            ]
        else:
            boundaries = [
                int(round(s * batch_per_epoch))
                for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]
            ]
        learning_rate = tf.train.piecewise_constant(
            global_step,
            boundaries,
            values=[
                FLAGS.learning_rate / (3.**i)
                for i in range(len(boundaries) + 1)
            ])
    else:
        learning_rate = FLAGS.learning_rate

    if FLAGS.optimizer == 'adam':
        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
    elif FLAGS.optimizer == 'momentum':
        opt = tf.train.MomentumOptimizer(learning_rate,
                                         FLAGS.momentum,
                                         use_nesterov=FLAGS.use_nesterov)
    elif FLAGS.optimizer == 'sgd':
        opt = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise RuntimeError(FLAGS.optimizer +
                           ' is an unknown optimizer type.')
    custom_sparsities = {
        'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,
        'layer3': FLAGS.end_sparsity * 0
    }

    if FLAGS.training_method == 'set':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseSETOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'static':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseStaticOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'momentum':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseMomentumOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            use_tpu=False)
    elif FLAGS.training_method == 'rigl':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseRigLOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale,
            use_tpu=False)
    elif FLAGS.training_method == 'snip':
        opt = sparse_optimizers.SparseSnipOptimizer(
            opt,
            mask_init_method=FLAGS.mask_init_method,
            default_sparsity=FLAGS.end_sparsity,
            custom_sparsity_map=custom_sparsities,
            use_tpu=False)
    elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' %
                         FLAGS.training_method)

    train_op = opt.minimize(cross_entropy_train, global_step=global_step)

    if FLAGS.training_method == 'prune':
        hparams_string = (
            'begin_pruning_step={0},sparsity_function_begin_step={0},'
            'end_pruning_step={1},sparsity_function_end_step={1},'
            'target_sparsity={2},pruning_frequency={3},'
            'threshold_decay={4}'.format(FLAGS.prune_begin_step,
                                         FLAGS.prune_end_step,
                                         FLAGS.end_sparsity,
                                         FLAGS.pruning_frequency,
                                         FLAGS.threshold_decay))
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
        pruning_hparams.set_hparam(
            'weight_sparsity_map',
            ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()])
        print(pruning_hparams)
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()
    weight_sparsity_levels = pruning.get_weight_sparsity()
    global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())
    tf.summary.scalar('test_accuracy', accuracy_test)
    tf.summary.scalar('global_sparsity', global_sparsity)
    for k, v in zip(pruning.get_masks(), weight_sparsity_levels):
        tf.summary.scalar('sparsity/%s' % k.name, v)
    if FLAGS.training_method in ('prune', 'snip', 'baseline'):
        mask_init_op = tf.no_op()
        tf.logging.info('No mask is set, starting dense.')
    else:
        all_masks = pruning.get_masks()
        mask_init_op = sparse_utils.get_mask_init_fn(all_masks,
                                                     FLAGS.mask_init_method,
                                                     FLAGS.end_sparsity,
                                                     custom_sparsities)

    if FLAGS.save_model:
        saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    hyper_params_string = '_'.join([
        FLAGS.network_type,
        str(FLAGS.batch_size),
        str(FLAGS.learning_rate),
        str(FLAGS.momentum), FLAGS.optimizer,
        str(FLAGS.l2_scale), FLAGS.training_method,
        str(FLAGS.prune_begin_step),
        str(FLAGS.prune_end_step),
        str(FLAGS.end_sparsity),
        str(FLAGS.pruning_frequency),
        str(FLAGS.seed)
    ])
    tf.io.gfile.makedirs(FLAGS.save_path)
    filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')
    merged_summary_op = tf.summary.merge_all()

    # Run session.
    if not use_model_pruning:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')
            sess.run([init_op])
            tic = time.time()
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                for i in range(FLAGS.num_epochs * num_batches):
                    sess.run([train_op])

                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %.4f, %.4f, %.4f' % (
                            i // num_batches, epoch_time, loss, accuracy)
                        print(log_str)
                        print(log_str, file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
    else:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            log_str = ','.join([
                'Epoch', 'Iteration', 'Test loss', 'Test accuracy',
                'G_Sparsity', 'Sparsity Layer 0', 'Sparsity Layer 1'
            ])
            sess.run(init_op)
            sess.run(mask_init_op)
            tic = time.time()
            mask_records = {}
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                print(log_str)
                print(log_str, file=outputfile)
                for i in range(FLAGS.num_epochs * num_batches):
                    if (FLAGS.mask_record_frequency > 0
                            and i % FLAGS.mask_record_frequency == 0):
                        mask_vals = sess.run(pruning.get_masks())
                        # Cast to bool to save space (np.bool is deprecated).
                        mask_records[i] = [a.astype(bool) for a in mask_vals]
                    sess.run([train_op])
                    weight_sparsity, global_sparsity_val = sess.run(
                        [weight_sparsity_levels, global_sparsity])
                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (
                            i // num_batches, i, loss, accuracy,
                            global_sparsity_val, weight_sparsity[0],
                            weight_sparsity[1])
                        print(log_str)
                        print(log_str, file=outputfile)
                        mask_vals = sess.run(pruning.get_masks())
                        if FLAGS.network_type == 'fc':
                            sparsities, sizes = get_compressed_fc(mask_vals)
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes))
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes),
                                  file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
            if mask_records:
                np.save(os.path.join(FLAGS.save_path, 'mask_records'),
                        mask_records)
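
One idiom in the training loops above is worth spelling out: (i % num_batches) == (-1 % num_batches) is true exactly on the last batch of each epoch, because -1 % n == n - 1 in Python:

num_batches = 4
epoch_ends = [i for i in range(12)
              if (i % num_batches) == (-1 % num_batches)]
print(epoch_ends)  # [3, 7, 11] -- the final batch index of each epoch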
Example #9
def train_function(training_method, loss, cross_loss, reg_loss, output_dir,
                   use_tpu):
    """Training script for resnet model.

  Args:
   training_method: string indicating pruning method used to compress model.
   loss: tensor float32 of the cross entropy + regularization losses.
   cross_loss: tensor, only cross entropy loss, passed for logging.
   reg_loss: tensor, only regularization loss, passed for logging.
   output_dir: string tensor indicating the directory to save summaries.
   use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """

    global_step = tf.train.get_global_step()

    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    learning_rate = lr_schedule(current_epoch)
    if FLAGS.use_adam:
        # We don't use step decrease for the learning rate.
        learning_rate = FLAGS.base_learning_rate * (FLAGS.train_batch_size /
                                                    256.0)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    else:
        optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                               momentum=FLAGS.momentum,
                                               use_nesterov=True)

    if use_tpu:
        # use CrossShardOptimizer when using TPU.
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    if training_method == 'set':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseSETOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            stateless_seed_offset=FLAGS.seed)
    elif training_method == 'static':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseStaticOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            stateless_seed_offset=FLAGS.seed)
    elif training_method == 'momentum':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseMomentumOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            stateless_seed_offset=FLAGS.seed,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            use_tpu=use_tpu)
    elif training_method == 'rigl':
        # We override the train op to also update the mask.
        optimizer = sparse_optimizers.SparseRigLOptimizer(
            optimizer,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            stateless_seed_offset=FLAGS.seed,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale,
            use_tpu=use_tpu)
    elif training_method == 'snip':
        optimizer = sparse_optimizers.SparseSnipOptimizer(
            optimizer,
            mask_init_method=FLAGS.mask_init_method,
            custom_sparsity_map=CUSTOM_SPARSITY_MAP,
            default_sparsity=FLAGS.end_sparsity,
            use_tpu=use_tpu)
    elif training_method in ('scratch', 'baseline'):
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' % training_method)
    # UPDATE_OPS needs to be added as a dependency due to batch norm
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops), tf.name_scope('train'):
        train_op = optimizer.minimize(loss, global_step)

    metrics = {
        'global_step': tf.train.get_or_create_global_step(),
        'loss': loss,
        'cross_loss': cross_loss,
        'reg_loss': reg_loss,
        'learning_rate': learning_rate,
        'current_epoch': current_epoch,
    }

    # Logging drop_fraction if dynamic sparse training.
    if training_method in ('set', 'momentum', 'rigl', 'static'):
        metrics['drop_fraction'] = optimizer.drop_fraction

    # Let's log some statistics from a single parameter-mask couple.
    # This is useful for debugging.
    test_var = pruning.get_weights()[0]
    test_var_mask = pruning.get_masks()[0]
    metrics.update({
        'fw_nz_weight': tf.count_nonzero(test_var),
        'fw_nz_mask': tf.count_nonzero(test_var_mask),
        'fw_l1_weight': tf.reduce_sum(tf.abs(test_var))
    })

    masks = pruning.get_masks()
    global_sparsity = sparse_utils.calculate_sparsity(masks)
    metrics['global_sparsity'] = global_sparsity
    metrics.update(
        utils.mask_summaries(masks[:4] + masks[-1:],
                             with_img=FLAGS.log_mask_imgs_each_iteration))

    host_call = (functools.partial(utils.host_call_fn,
                                   output_dir), utils.format_tensors(metrics))

    return host_call, train_op
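
The returned pair is what TPUEstimatorSpec expects for its host_call argument: a function plus the tensors to feed it. A hedged sketch of the consuming model_fn (names are illustrative):

host_call, train_op = train_function(
    training_method, loss, cross_loss, reg_loss, output_dir, use_tpu)
spec = tf.contrib.tpu.TPUEstimatorSpec(
    mode=tf.estimator.ModeKeys.TRAIN,
    loss=loss,
    train_op=train_op,
    host_call=host_call)  # summaries are written on the host, not the TPU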
Example #10
def train_function(pruning_method, loss, output_dir, use_tpu):
    """Training script for resnet model.

  Args:
   pruning_method: string indicating pruning method used to compress model.
   loss: tensor float32 of the cross entropy + regularization losses.
   output_dir: string tensor indicating the directory to save summaries.
   use_tpu: boolean indicating whether to run script on a tpu.

  Returns:
    host_call: summary tensors to be computed at each training step.
    train_op: the optimization term.
  """

    global_step = tf.train.get_global_step()

    steps_per_epoch = FLAGS.num_train_images / FLAGS.train_batch_size
    current_epoch = (tf.cast(global_step, tf.float32) / steps_per_epoch)
    learning_rate = lr_schedule(current_epoch)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate,
                                           momentum=FLAGS.momentum,
                                           use_nesterov=True)

    if use_tpu:
        # use CrossShardOptimizer when using TPU.
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    # UPDATE_OPS needs to be added as a dependency due to batch norm
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops), tf.name_scope('train'):
        train_op = optimizer.minimize(loss, global_step)

    if not use_tpu:
        if FLAGS.num_workers > 0:
            optimizer = tf.train.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=FLAGS.num_workers,
                total_num_replicas=FLAGS.num_workers)
            optimizer.make_session_run_hook(True)

    metrics = {
        'global_step': tf.train.get_or_create_global_step(),
        'loss': loss,
        'learning_rate': learning_rate,
        'current_epoch': current_epoch
    }

    if pruning_method == 'threshold':
        # construct the necessary hparams string from the FLAGS
        hparams_string = ('begin_pruning_step={0},'
                          'sparsity_function_begin_step={0},'
                          'end_pruning_step={1},'
                          'sparsity_function_end_step={1},'
                          'target_sparsity={2},'
                          'pruning_frequency={3},'
                          'threshold_decay=0,'
                          'use_tpu={4}'.format(
                              FLAGS.sparsity_begin_step,
                              FLAGS.sparsity_end_step,
                              FLAGS.end_sparsity,
                              FLAGS.pruning_frequency,
                              FLAGS.use_tpu,
                          ))

        # Parse pruning hyperparameters
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

        # The first layer has so few parameters that we don't need to prune
        # it, and pruning it at higher sparsity levels has very negative
        # effects.
        weight_sparsity_map = []
        if FLAGS.prune_first_layer and FLAGS.first_layer_sparsity >= 0.:
            weight_sparsity_map.append(
                'resnet_model/initial_conv:%f' % FLAGS.first_layer_sparsity)
        if FLAGS.prune_last_layer and FLAGS.last_layer_sparsity >= 0:
            weight_sparsity_map.append(
                'resnet_model/final_dense:%f' % FLAGS.last_layer_sparsity)
        # Calling set_hparam twice would overwrite the first entry, so both
        # layer overrides are collected into a single list.
        if weight_sparsity_map:
            pruning_hparams.set_hparam('weight_sparsity_map',
                                       weight_sparsity_map)

        # Create a pruning object using the pruning hyperparameters
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

        # We override the train op to also update the mask.
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()

        masks = pruning.get_masks()
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'scratch':
        masks = pruning.get_masks()
        # Make sure the masks have the sparsity we expect and that it
        # doesn't change.
        metrics.update(utils.mask_summaries(masks))
    elif pruning_method == 'variational_dropout':
        masks = utils.add_vd_pruning_summaries(
            threshold=FLAGS.log_alpha_threshold)
        metrics.update(masks)
    elif pruning_method == 'l0_regularization':
        summaries = utils.add_l0_summaries()
        metrics.update(summaries)
    elif pruning_method == 'baseline':
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' % pruning_method)

    host_call = (functools.partial(utils.host_call_fn,
                                   output_dir), utils.format_tensors(metrics))

    return host_call, train_op
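
For reference, the hparams string assembled above parses into ordinary attributes; a minimal illustration with fixed values, assuming the contrib model_pruning API used in the example:

hparams = pruning.get_pruning_hparams().parse(
    'begin_pruning_step=1000,end_pruning_step=5000,'
    'target_sparsity=0.9,pruning_frequency=100')
print(hparams.target_sparsity)  # 0.9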
Example #11
 def get_masks(self):
     return pruning.get_masks()
Example #12
def wide_resnet_w_pruning(features, labels, mode, params):
  """The model_fn for ResNet wide with pruning.

  Args:
    features: A float32 batch of images.
    labels: A int32 batch of labels.
    mode: Specifies whether training or evaluation.
    params: Dictionary of parameters passed to the model.

  Returns:
    An EstimatorSpec for the model.

  Raises:
      ValueError: if mode is not recognized as train or eval.
  """

  if isinstance(features, dict):
    features = features['feature']

  train_dir = params['train_dir']
  training_method = params['training_method']

  global_step, accuracy, top_5_accuracy, logits = build_model(
      mode=mode,
      images=features,
      labels=labels,
      training_method=training_method,
      num_classes=FLAGS.num_classes,
      depth=FLAGS.resnet_depth,
      width=FLAGS.resnet_width)

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        export_outputs={
            'classify': tf.estimator.export.PredictOutput(predictions)
        })

  with tf.name_scope('computing_cross_entropy_loss'):
    entropy_loss = tf.losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits)
    tf.summary.scalar('cross_entropy_loss', entropy_loss)

  with tf.name_scope('computing_total_loss'):
    total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

  if mode == tf.estimator.ModeKeys.TRAIN:
    hooks, eval_metrics, train_op = train_fn(training_method, global_step,
                                             total_loss, train_dir, accuracy,
                                             top_5_accuracy)
  elif mode == tf.estimator.ModeKeys.EVAL:
    hooks = None
    train_op = None
    with tf.name_scope('summaries'):
      eval_metrics = create_eval_metrics(labels, logits)
  else:
    raise ValueError('mode not recognized as training or eval.')

  if FLAGS.training_method in ('prune', 'snip', 'baseline'):
    scaffold = None
    tf.logging.info('No mask is set, starting dense.')
  else:
    all_masks = pruning.get_masks()
    assigner = sparse_utils.get_mask_init_fn(
        all_masks, FLAGS.mask_init_method, FLAGS.end_sparsity, {})
    def init_fn(scaffold, session):
      """A callable for restoring variable from a checkpoint."""
      del scaffold  # Unused.
      session.run(assigner)
    scaffold = tf.train.Scaffold(init_fn=init_fn)

  return tf.estimator.EstimatorSpec(
      mode=mode,
      training_hooks=hooks,
      loss=total_loss,
      train_op=train_op,
      eval_metric_ops=eval_metrics,
      scaffold=scaffold)
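
A model_fn with this signature plugs into the standard Estimator API; a hedged sketch where train_input_fn, FLAGS.train_dir, and FLAGS.train_steps are illustrative names:

estimator = tf.estimator.Estimator(
    model_fn=wide_resnet_w_pruning,
    model_dir=FLAGS.train_dir,  # illustrative flag
    params={'train_dir': FLAGS.train_dir,
            'training_method': FLAGS.training_method})
estimator.train(input_fn=train_input_fn, max_steps=FLAGS.train_steps)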
Example #13
def train_fn(training_method, global_step, total_loss, train_dir, accuracy,
             top_5_accuracy):
  """Training script for resnet model.

  Args:
   training_method: specifies the method used to sparsify networks.
   global_step: the current step of training/eval.
   total_loss: tensor float32 of the cross entropy + regularization losses.
   train_dir: string specifying the directory where summaries are saved.
   accuracy: tensor float32 batch classification accuracy.
   top_5_accuracy: tensor float32 batch classification accuracy (top-5
     classes).

  Returns:
    hooks: summary tensors to be computed at each training step.
    eval_metrics: set to None during training.
    train_op: the optimization term.
  """
  # Roughly drops the learning rate every 30k steps.
  boundaries = [30000, 60000, 90000]
  if FLAGS.training_steps_multiplier != 1.0:
    multiplier = FLAGS.training_steps_multiplier
    boundaries = [int(x * multiplier) for x in boundaries]
    tf.logging.info(
        'Learning Rate boundaries are updated with multiplier:%.2f', multiplier)

  learning_rate = tf.train.piecewise_constant(
      global_step,
      boundaries,
      values=[0.1 / (5.**i) for i in range(len(boundaries) + 1)],
      name='lr_schedule')

  optimizer = tf.train.MomentumOptimizer(
      learning_rate, momentum=FLAGS.momentum, use_nesterov=True)

  if training_method == 'set':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseSETOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'static':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseStaticOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal)
  elif training_method == 'momentum':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseMomentumOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, momentum=FLAGS.s_momentum,
        frequency=FLAGS.maskupdate_frequency, drop_fraction=FLAGS.drop_fraction,
        grow_init=FLAGS.grow_init,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal, use_tpu=False)
  elif training_method == 'rigl':
    # We override the train op to also update the mask.
    optimizer = sparse_optimizers.SparseRigLOptimizer(
        optimizer, begin_step=FLAGS.maskupdate_begin_step,
        end_step=FLAGS.maskupdate_end_step, grow_init=FLAGS.grow_init,
        frequency=FLAGS.maskupdate_frequency,
        drop_fraction=FLAGS.drop_fraction,
        drop_fraction_anneal=FLAGS.drop_fraction_anneal,
        initial_acc_scale=FLAGS.rigl_acc_scale, use_tpu=False)
  elif training_method == 'snip':
    optimizer = sparse_optimizers.SparseSnipOptimizer(
        optimizer, mask_init_method=FLAGS.mask_init_method,
        default_sparsity=FLAGS.end_sparsity, use_tpu=False)
  elif training_method in ('scratch', 'baseline', 'prune'):
    pass
  else:
    raise ValueError('Unsupported pruning method: %s' % training_method)
  # Create the training op
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss, global_step)

  if training_method == 'prune':
    # construct the necessary hparams string from the FLAGS
    hparams_string = ('begin_pruning_step={0},'
                      'sparsity_function_begin_step={0},'
                      'end_pruning_step={1},'
                      'sparsity_function_end_step={1},'
                      'target_sparsity={2},'
                      'pruning_frequency={3},'
                      'threshold_decay=0,'
                      'use_tpu={4}'.format(
                          FLAGS.sparsity_begin_step,
                          FLAGS.sparsity_end_step,
                          FLAGS.end_sparsity,
                          FLAGS.pruning_frequency,
                          False,
                      ))
    # Parse pruning hyperparameters
    pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)

    # Create a pruning object using the pruning hyperparameters
    pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)

    tf.logging.info('starting mask update op')

    # We override the train op to also update the mask.
    with tf.control_dependencies([train_op]):
      train_op = pruning_obj.conditional_mask_update_op()

  masks = pruning.get_masks()
  mask_metrics = utils.mask_summaries(masks)
  for name, tensor in mask_metrics.items():
    tf.summary.scalar(name, tensor)

  tf.summary.scalar('learning_rate', learning_rate)
  tf.summary.scalar('accuracy', accuracy)
  tf.summary.scalar('total_loss', total_loss)
  tf.summary.scalar('top_5_accuracy', top_5_accuracy)
  # Logging drop_fraction if dynamic sparse training.
  if training_method in ('set', 'momentum', 'rigl', 'static'):
    tf.summary.scalar('drop_fraction', optimizer.drop_fraction)

  summary_op = tf.summary.merge_all()
  summary_hook = tf.train.SummarySaverHook(
      save_secs=300, output_dir=train_dir, summary_op=summary_op)
  hooks = [summary_hook]
  eval_metrics = None

  return hooks, eval_metrics, train_op
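
The piecewise schedule above divides the learning rate by 5 at each boundary; spelled out with the default values (illustrative):

boundaries = [30000, 60000, 90000]
values = [0.1 / (5.**i) for i in range(len(boundaries) + 1)]
print(values)  # [0.1, 0.02, 0.004, 0.0008]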