Example #1
0
    def testWeightSpecificSparsity(self):
        """Check that weight_sparsity_map overrides target_sparsity per weight.

        layer1/weights should follow the global target_sparsity (0.5) while
        layer2/weights is mapped to 0.75. The conditional mask update is run
        for 110 steps (past end_pruning_step=100) before sparsities are read.
        """
        param_list = [
            "begin_pruning_step=1", "pruning_frequency=1",
            "end_pruning_step=100", "target_sparsity=0.5",
            "weight_sparsity_map=[layer2/weights:0.75]", "threshold_decay=0.0"
        ]
        test_spec = ",".join(param_list)
        pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)

        with variable_scope.variable_scope("layer1"):
            w1 = variables.Variable(math_ops.linspace(1.0, 100.0, 100),
                                    name="weights")
            _ = pruning.apply_mask(w1)
        with variable_scope.variable_scope("layer2"):
            w2 = variables.Variable(math_ops.linspace(1.0, 100.0, 100),
                                    name="weights")
            _ = pruning.apply_mask(w2)

        p = pruning.Pruning(pruning_hparams)
        mask_update_op = p.conditional_mask_update_op()
        increment_global_step = state_ops.assign_add(self.global_step, 1)

        # test_session() is deprecated in tf.test.TestCase; cached_session()
        # is the documented replacement.
        with self.cached_session() as session:
            variables.global_variables_initializer().run()
            for _ in range(110):
                session.run(mask_update_op)
                session.run(increment_global_step)

            self.assertAllEqual(session.run(pruning.get_weight_sparsity()),
                                [0.5, 0.75])
Example #2
0
    def testPerLayerBlockSparsity(self):
        """Check that block_dims_map gives each layer its own pruning blocks.

        layer1 is pruned element-wise (1x1 blocks) and layer2 in 1x2 blocks
        pooled with AVG; both are pruned to the shared sparsity of 0.5, so
        the smaller-magnitude half of each layer (by block) is masked out.
        """
        hparam_spec = ",".join([
            "block_dims_map=[layer1/weights:1x1,layer2/weights:1x2]",
            "block_pooling_function=AVG",
            "threshold_decay=0.0",
        ])
        hparams = pruning.get_pruning_hparams().parse(hparam_spec)

        # (scope, constant weight values) in the order the masks are created.
        layer_weights = [
            ("layer1", [[-0.1, 0.1], [-0.2, 0.2]]),
            ("layer2", [[0.1, 0.1, 0.3, 0.3], [0.2, 0.2, 0.4, 0.4]]),
        ]
        for scope_name, values in layer_weights:
            with variable_scope.variable_scope(scope_name):
                pruning.apply_mask(constant_op.constant(values, name="weights"))

        target = variables.VariableV1(0.5, name="sparsity")
        update_masks = pruning.Pruning(hparams, sparsity=target).mask_update_op()

        with self.cached_session() as sess:
            variables.global_variables_initializer().run()
            sess.run(update_masks)
            first_mask = sess.run(pruning.get_masks()[0])
            second_mask = sess.run(pruning.get_masks()[1])

            self.assertAllEqual(sess.run(pruning.get_weight_sparsity()),
                                [0.5, 0.5])

            self.assertAllEqual(first_mask, [[0.0, 0.0], [1., 1.]])
            self.assertAllEqual(second_mask, [[0, 0, 1., 1.], [0, 0, 1., 1.]])
Example #3
0
  def testWeightSpecificSparsity(self):
    """Check that weight_sparsity_map overrides target_sparsity per weight.

    layer1/weights follows the global target of 0.5 while layer2/weights is
    mapped to 0.75; after 110 conditional updates (past end_pruning_step=100)
    both sparsities are expected exactly.
    """
    spec = ",".join([
        "begin_pruning_step=1",
        "pruning_frequency=1",
        "end_pruning_step=100",
        "target_sparsity=0.5",
        "weight_sparsity_map=[layer2/weights:0.75]",
        "threshold_decay=0.0",
    ])
    hparams = pruning.get_pruning_hparams().parse(spec)

    # Two identical 100-element weight vectors, one per scope.
    for scope_name in ("layer1", "layer2"):
      with variable_scope.variable_scope(scope_name):
        weights = variables.Variable(
            math_ops.linspace(1.0, 100.0, 100), name="weights")
        pruning.apply_mask(weights)

    conditional_update = pruning.Pruning(hparams).conditional_mask_update_op()
    step_increment = state_ops.assign_add(self.global_step, 1)

    with self.cached_session() as sess:
      variables.global_variables_initializer().run()
      for _ in range(110):
        sess.run(conditional_update)
        sess.run(step_increment)

      self.assertAllEqual(
          sess.run(pruning.get_weight_sparsity()), [0.5, 0.75])
Example #4
0
 def __init__(self,
              input_size,
              output_size,
              model_path: str,
              momentum=0.9,
              reg_str=0.0005,
              scope='ConvNet',
              pruning_start=int(10e4),
              pruning_end=int(10e5),
              pruning_freq=int(10),
              sparsity_start=0,
              sparsity_end=int(10e5),
              target_sparsity=0.0,
              dropout=0.5,
              initial_sparsity=0,
              wd=0.0):
     """Build the ConvNet graph and attach a model-pruning schedule to it.

     Args:
         input_size: Forwarded to the base-class constructor.
         output_size: Forwarded to the base-class constructor.
         model_path: Forwarded to the base-class constructor.
         momentum: Stored as ``self.momentum``.
         reg_str: Stored as ``self.reg_str``.
         scope: Logger name and the ``name`` pruning hyperparameter.
         pruning_start: Passed as ``begin_pruning_step``.
         pruning_end: Passed as ``end_pruning_step``.
         pruning_freq: Passed as ``pruning_frequency``.
         sparsity_start: Passed as ``sparsity_function_begin_step``.
         sparsity_end: Passed as ``sparsity_function_end_step``.
         target_sparsity: Passed as ``target_sparsity``.
         dropout: Stored as ``self.dropout``.
         initial_sparsity: Passed as ``initial_sparsity``.
         wd: Stored as ``self.wd``.
     """
     super(ConvNet, self).__init__(input_size=input_size,
                                   output_size=output_size,
                                   model_path=model_path)
     self.scope = scope
     self.momentum = momentum
     self.reg_str = reg_str
     self.dropout = dropout
     self.logger = get_logger(scope)
     self.wd = wd
     self.logger.info("creating graph...")
     with self.graph.as_default():
         self.global_step = tf.Variable(0, trainable=False)
         self._build_placeholders()
         self.logits = self._build_model()
         self.weights_matrices = pruning.get_masked_weights()
         self.sparsity = pruning.get_weight_sparsity()
         self.loss = self._loss()
         self.train_op = self._optimizer()
         self._create_metrics()
         # The saver snapshots tf.global_variables() at this point, so any
         # variable created further down is not part of its var_list.
         self.saver = tf.train.Saver(var_list=tf.global_variables())

         # Pruning schedule; the exponent of the sparsity function is fixed at 3.
         spec_template = ('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                          ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                          'pruning_frequency={},initial_sparsity={},'
                          ' sparsity_function_exponent={}')
         spec_values = (scope, pruning_start, pruning_end, target_sparsity,
                        sparsity_start, sparsity_end, pruning_freq,
                        initial_sparsity, 3)
         self.hparams = pruning.get_pruning_hparams().parse(
             spec_template.format(*spec_values))

         # The schedule is driven by self.global_step: the further training
         # has progressed, the closer the sparsity is to its end value.
         self.pruning_obj = pruning.Pruning(self.hparams,
                                            global_step=self.global_step)
         # Running mask_update_op is what actually prunes the model.
         self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
         # Initialize every variable in the graph, including those created
         # by the pruning objects above.
         self.init_variables(tf.global_variables())
Example #5
0
 def __init__(self,
              actor_input_dim,
              actor_output_dim,
              model_path,
              redundancy=None,
              last_measure=10e4,
              tau=0.01):
     """Build the actor graph and attach a model-pruning schedule to it.

     The pruning schedule values are read from the module-level ``cfg``
     object rather than from constructor arguments.

     Args:
         actor_input_dim: Stored as the shape ``(None, actor_input_dim)``.
         actor_output_dim: Stored as the shape ``(None, actor_output_dim)``.
         model_path: Forwarded to the base-class constructor.
         redundancy: Stored as ``self.redundancy``.
         last_measure: Stored as ``self.last_measure``.
         tau: Stored as ``self.tau``.
     """
     super(StudentActor, self).__init__(model_path=model_path)
     self.actor_input_dim = (None, actor_input_dim)
     self.actor_output_dim = (None, actor_output_dim)
     self.tau = tau
     self.redundancy = redundancy
     self.last_measure = last_measure
     with self.graph.as_default():
         self.actor_global_step = tf.Variable(0, trainable=False)
         self._build_placeholders()
         self.actor_logits = self._build_actor()
         self.loss = self._build_loss()
         # Only trainable variables under the 'actor' scope are saved.
         self.actor_parameters = tf.get_collection(
             tf.GraphKeys.TRAINABLE_VARIABLES, scope='actor')
         self.actor_pruned_weight_matrices = pruning.get_masked_weights()
         self.actor_train_op = self._build_actor_train_op()
         self.actor_saver = tf.train.Saver(var_list=self.actor_parameters,
                                           max_to_keep=100)
         # NOTE(review): init_variables runs here and again at the end of
         # the constructor; this first call looks redundant — confirm it has
         # no other effect before removing it.
         self.init_variables(tf.global_variables())
         self.sparsity = pruning.get_weight_sparsity()

         # Pruning schedule; the exponent of the sparsity function is fixed at 3.
         spec_template = ('name={}, begin_pruning_step={}, end_pruning_step={}, target_sparsity={},'
                          ' sparsity_function_begin_step={},sparsity_function_end_step={},'
                          'pruning_frequency={},initial_sparsity={},'
                          ' sparsity_function_exponent={}')
         spec_values = ('Actor', cfg.pruning_start, cfg.pruning_end,
                        cfg.target_sparsity, cfg.sparsity_start,
                        cfg.sparsity_end, cfg.pruning_freq,
                        cfg.initial_sparsity, 3)
         self.hparams = pruning.get_pruning_hparams().parse(
             spec_template.format(*spec_values))

         # The schedule is driven by actor_global_step: the further training
         # has progressed, the closer the sparsity is to its end value.
         self.pruning_obj = pruning.Pruning(
             self.hparams, global_step=self.actor_global_step)
         # Running mask_update_op is what actually prunes the actor.
         self.mask_update_op = self.pruning_obj.conditional_mask_update_op()
         # Initialize every variable in the graph, including those created
         # by the pruning objects above.
         self.init_variables(tf.global_variables())
def main(unused_args):
    """Train an MNIST classifier, optionally with sparse training or pruning.

    All configuration comes from ``FLAGS``. Training runs for
    ``FLAGS.num_epochs * num_batches`` steps. At the end of every epoch:
      * the full test set is evaluated;
      * a summary is written to ``FLAGS.save_path``;
      * a log line is printed and written to
        ``<FLAGS.save_path>/<hyper_params_string>.txt``.
    When masks are in use, the log line also reports global and per-layer
    sparsity. Mask snapshots can be recorded every
    ``FLAGS.mask_record_frequency`` steps.

    Args:
        unused_args: Ignored.

    Raises:
        RuntimeError: If ``FLAGS.network_type`` or ``FLAGS.optimizer`` is not
            recognised.
        ValueError: If ``FLAGS.training_method`` is not supported.
    """
    tf.set_random_seed(FLAGS.seed)
    tf.get_variable_scope().set_use_resource(True)
    np.random.seed(FLAGS.seed)

    # Load the MNIST data and set up an iterator.
    mnist_data = input_data.read_data_sets(FLAGS.mnist,
                                           one_hot=False,
                                           validation_size=0)
    train_images = mnist_data.train.images
    test_images = mnist_data.test.images
    if FLAGS.input_mask_path:
        # Keep only the input features whose row in the saved layer1 mask
        # has at least one non-zero entry.
        reader = tf.train.load_checkpoint(FLAGS.input_mask_path)
        input_mask = reader.get_tensor('layer1/mask')
        indices = np.sum(input_mask, axis=1) != 0
        train_images = train_images[:, indices]
        test_images = test_images[:, indices]
    dataset = tf.data.Dataset.from_tensor_slices(
        (train_images, mnist_data.train.labels.astype(np.int32)))
    num_batches = mnist_data.train.images.shape[0] // FLAGS.batch_size
    dataset = dataset.shuffle(buffer_size=mnist_data.train.images.shape[0])
    batched_dataset = dataset.repeat(FLAGS.num_epochs).batch(FLAGS.batch_size)
    iterator = batched_dataset.make_one_shot_iterator()

    # The whole test set is one batch, consumed once per epoch below.
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_images, mnist_data.test.labels.astype(np.int32)))
    num_test_images = mnist_data.test.images.shape[0]
    test_dataset = test_dataset.repeat(FLAGS.num_epochs).batch(num_test_images)
    test_iterator = test_dataset.make_one_shot_iterator()

    # Set up loss function.
    use_model_pruning = FLAGS.training_method != 'baseline'

    if FLAGS.network_type == 'fc':
        cross_entropy_train, _ = mnist_network_fc(
            iterator.get_next(), model_pruning=use_model_pruning)
        cross_entropy_test, accuracy_test = mnist_network_fc(
            test_iterator.get_next(),
            reuse=True,
            model_pruning=use_model_pruning)
    else:
        # Report the flag that was actually checked (there is no
        # FLAGS.network; referencing it raised AttributeError instead).
        raise RuntimeError(FLAGS.network_type + ' is an unknown network type.')

    # Building the network twice (train + test) registers the pruning
    # variables twice in these collections; keep only the first half.
    # TODO test the following with the convnet or any other network.
    if use_model_pruning:
        for k in ('masks', 'masked_weights', 'thresholds', 'kernel'):
            collection = tf.get_collection_ref(k)
            del collection[len(collection) // 2:]
            print(tf.get_collection_ref(k))

    # Set up optimizer and update ops.
    global_step = tf.train.get_or_create_global_step()
    batch_per_epoch = num_batches

    if FLAGS.optimizer != 'adam':
        if not use_model_pruning:
            boundaries = [
                int(round(s * batch_per_epoch)) for s in [60, 70, 80]
            ]
        else:
            boundaries = [
                int(round(s * batch_per_epoch))
                for s in [FLAGS.lr_drop_epoch, FLAGS.lr_drop_epoch + 20]
            ]
        # The learning rate is divided by 3 at every boundary.
        learning_rate = tf.train.piecewise_constant(
            global_step,
            boundaries,
            values=[
                FLAGS.learning_rate / (3.**i)
                for i in range(len(boundaries) + 1)
            ])
    else:
        learning_rate = FLAGS.learning_rate

    if FLAGS.optimizer == 'adam':
        opt = tf.train.AdamOptimizer(learning_rate)
    elif FLAGS.optimizer == 'momentum':
        opt = tf.train.MomentumOptimizer(learning_rate,
                                         FLAGS.momentum,
                                         use_nesterov=FLAGS.use_nesterov)
    elif FLAGS.optimizer == 'sgd':
        opt = tf.train.GradientDescentOptimizer(learning_rate)
    else:
        raise RuntimeError(FLAGS.optimizer + ' is unknown optimizer type')
    # Per-layer sparsity overrides; layer3 is kept dense (sparsity 0).
    custom_sparsities = {
        'layer2': FLAGS.end_sparsity * FLAGS.sparsity_scale,
        'layer3': FLAGS.end_sparsity * 0
    }

    if FLAGS.training_method == 'set':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseSETOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'static':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseStaticOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal)
    elif FLAGS.training_method == 'momentum':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseMomentumOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            momentum=FLAGS.s_momentum,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            grow_init=FLAGS.grow_init,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            use_tpu=False)
    elif FLAGS.training_method == 'rigl':
        # We override the train op to also update the mask.
        opt = sparse_optimizers.SparseRigLOptimizer(
            opt,
            begin_step=FLAGS.maskupdate_begin_step,
            end_step=FLAGS.maskupdate_end_step,
            grow_init=FLAGS.grow_init,
            frequency=FLAGS.maskupdate_frequency,
            drop_fraction=FLAGS.drop_fraction,
            drop_fraction_anneal=FLAGS.drop_fraction_anneal,
            initial_acc_scale=FLAGS.rigl_acc_scale,
            use_tpu=False)
    elif FLAGS.training_method == 'snip':
        opt = sparse_optimizers.SparseSnipOptimizer(
            opt,
            mask_init_method=FLAGS.mask_init_method,
            default_sparsity=FLAGS.end_sparsity,
            custom_sparsity_map=custom_sparsities,
            use_tpu=False)
    elif FLAGS.training_method in ('scratch', 'baseline', 'prune'):
        pass
    else:
        raise ValueError('Unsupported pruning method: %s' %
                         FLAGS.training_method)

    train_op = opt.minimize(cross_entropy_train, global_step=global_step)

    if FLAGS.training_method == 'prune':
        hparams_string = (
            'begin_pruning_step={0},sparsity_function_begin_step={0},'
            'end_pruning_step={1},sparsity_function_end_step={1},'
            'target_sparsity={2},pruning_frequency={3},'
            'threshold_decay={4}'.format(FLAGS.prune_begin_step,
                                         FLAGS.prune_end_step,
                                         FLAGS.end_sparsity,
                                         FLAGS.pruning_frequency,
                                         FLAGS.threshold_decay))
        pruning_hparams = pruning.get_pruning_hparams().parse(hparams_string)
        pruning_hparams.set_hparam(
            'weight_sparsity_map',
            ['{0}:{1}'.format(k, v) for k, v in custom_sparsities.items()])
        print(pruning_hparams)
        pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step)
        # The mask update runs only after the gradient step has been applied.
        with tf.control_dependencies([train_op]):
            train_op = pruning_obj.conditional_mask_update_op()
    weight_sparsity_levels = pruning.get_weight_sparsity()
    global_sparsity = sparse_utils.calculate_sparsity(pruning.get_masks())
    tf.summary.scalar('test_accuracy', accuracy_test)
    tf.summary.scalar('global_sparsity', global_sparsity)
    for k, v in zip(pruning.get_masks(), weight_sparsity_levels):
        tf.summary.scalar('sparsity/%s' % k.name, v)
    if FLAGS.training_method in ('prune', 'snip', 'baseline'):
        mask_init_op = tf.no_op()
        tf.logging.info('No mask is set, starting dense.')
    else:
        all_masks = pruning.get_masks()
        mask_init_op = sparse_utils.get_mask_init_fn(all_masks,
                                                     FLAGS.mask_init_method,
                                                     FLAGS.end_sparsity,
                                                     custom_sparsities)

    if FLAGS.save_model:
        saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    hyper_params_string = '_'.join([
        FLAGS.network_type,
        str(FLAGS.batch_size),
        str(FLAGS.learning_rate),
        str(FLAGS.momentum), FLAGS.optimizer,
        str(FLAGS.l2_scale), FLAGS.training_method,
        str(FLAGS.prune_begin_step),
        str(FLAGS.prune_end_step),
        str(FLAGS.end_sparsity),
        str(FLAGS.pruning_frequency),
        str(FLAGS.seed)
    ])
    tf.io.gfile.makedirs(FLAGS.save_path)
    filename = os.path.join(FLAGS.save_path, hyper_params_string + '.txt')
    merged_summary_op = tf.summary.merge_all()

    # Run session.
    if not use_model_pruning:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            print('Epoch', 'Epoch time', 'Test loss', 'Test accuracy')
            sess.run([init_op])
            tic = time.time()
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                for i in range(FLAGS.num_epochs * num_batches):
                    sess.run([train_op])

                    # True on the last batch of each epoch
                    # (-1 % n == n - 1 in Python).
                    if (i % num_batches) == (-1 % num_batches):
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %.4f, %.4f, %.4f' % (
                            i // num_batches, epoch_time, loss, accuracy)
                        print(log_str)
                        print(log_str, file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
    else:
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                FLAGS.save_path, graph=tf.get_default_graph())
            log_str = ','.join([
                'Epoch', 'Iteration', 'Test loss', 'Test accuracy',
                'G_Sparsity', 'Sparsity Layer 0', 'Sparsity Layer 1'
            ])
            sess.run(init_op)
            sess.run(mask_init_op)
            tic = time.time()
            mask_records = {}
            with tf.io.gfile.GFile(filename, 'w') as outputfile:
                print(log_str)
                print(log_str, file=outputfile)
                for i in range(FLAGS.num_epochs * num_batches):
                    if (FLAGS.mask_record_frequency > 0
                            and i % FLAGS.mask_record_frequency == 0):
                        mask_vals = sess.run(pruning.get_masks())
                        # Cast into bool to save space. (np.bool was removed
                        # in NumPy 1.24; the builtin gives the same dtype.)
                        mask_records[i] = [a.astype(bool) for a in mask_vals]
                    sess.run([train_op])
                    # True on the last batch of each epoch
                    # (-1 % n == n - 1 in Python).
                    if (i % num_batches) == (-1 % num_batches):
                        # Sparsity is only reported here, so read it once per
                        # epoch (right after this step's update) instead of
                        # after every training step.
                        weight_sparsity, global_sparsity_val = sess.run(
                            [weight_sparsity_levels, global_sparsity])
                        epoch_time = time.time() - tic
                        loss, accuracy, summary = sess.run([
                            cross_entropy_test, accuracy_test,
                            merged_summary_op
                        ])
                        # Write logs at every test iteration.
                        summary_writer.add_summary(summary, i)
                        log_str = '%d, %d, %.4f, %.4f, %.4f, %.4f, %.4f' % (
                            i // num_batches, i, loss, accuracy,
                            global_sparsity_val, weight_sparsity[0],
                            weight_sparsity[1])
                        print(log_str)
                        print(log_str, file=outputfile)
                        mask_vals = sess.run(pruning.get_masks())
                        if FLAGS.network_type == 'fc':
                            sparsities, sizes = get_compressed_fc(mask_vals)
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes))
                            print('[COMPRESSED SPARSITIES/SHAPE]: %s %s' %
                                  (sparsities, sizes),
                                  file=outputfile)
                        tic = time.time()
            if FLAGS.save_model:
                saver.save(sess, os.path.join(FLAGS.save_path, 'model.ckpt'))
            if mask_records:
                # A dict is stored as a 0-d object array (np.load needs
                # allow_pickle=True to read it back).
                np.save(os.path.join(FLAGS.save_path, 'mask_records'),
                        mask_records)