def testTaskWithUnstructuredSparsity(self, config_name):
    """Prunes a model built from `config_name` and runs train/validation steps.

    Builds the image-classification task on local test data, attaches the
    pruning action when the config enables pruning, checks the model was
    pruned, and validates the metrics of one train and one validation step.

    Args:
      config_name: Name of a registered experiment config.
    """
    config = self._make_local_data_config(config_name)
    task, model, metrics, iterator, optimizer = (
        self._build_training_setup(config))

    if config.task.pruning:
        # This is an auxiliary initialization required to prune a model, which
        # is originally done in the train library.
        actions.PruningAction(
            export_dir=tempfile.gettempdir(), model=model, optimizer=optimizer)

    # Check all layers and target weights are successfully pruned.
    self._validate_model_pruned(model, config_name)

    logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    self._validate_metrics(logs, metrics)

    logs = task.validation_step(next(iterator), model, metrics=metrics)
    self._validate_metrics(logs, metrics)


# NOTE(review): this test was previously nested inside
# testTaskWithUnstructuredSparsity and therefore never ran. It needs
# `config_name` supplied the same way as that test (its parameterization
# decorator is outside this chunk) -- add the matching decorator here.
def testTaskWithStructuredSparsity(self, config_name):
    """Prunes a model with 2-by-4 structured sparsity and checks the pattern.

    Same flow as the unstructured test, but sets
    `config.task.pruning.sparsity_m_by_n = (2, 4)` and pruning frequency 1.
    After a train and a validation step it checks that the weights follow
    the 2x4 pattern.

    Args:
      config_name: Name of a registered experiment config; its `task.pruning`
        section must be present, since it is mutated here.
    """
    config = self._make_local_data_config(config_name)

    # Add structured sparsity.
    config.task.pruning.sparsity_m_by_n = (2, 4)
    config.task.pruning.frequency = 1

    task, model, metrics, iterator, optimizer = (
        self._build_training_setup(config))

    # This is an auxiliary initialization required to prune a model, which is
    # originally done in the train library. The handle is kept so its pruning
    # callback can be driven manually below.
    pruning_actions = actions.PruningAction(
        export_dir=tempfile.gettempdir(), model=model, optimizer=optimizer)

    # Check all layers and target weights are successfully pruned.
    self._validate_model_pruned(model, config_name)

    logs = task.train_step(next(iterator), model, optimizer, metrics=metrics)
    self._validate_metrics(logs, metrics)

    logs = task.validation_step(next(iterator), model, metrics=metrics)
    self._validate_metrics(logs, metrics)

    # Fire the pruning-step callback's epoch-end hook before inspecting the
    # weights (presumably this applies the pruning masks -- confirm against
    # PruningAction / UpdatePruningStep).
    pruning_actions.update_pruning_step.on_epoch_end(batch=None)
    # Check whether the weights are pruned in 2x4 pattern.
    self._check_2x4_sparsity(model)


def _make_local_data_config(self, config_name):
    """Returns the experiment config for `config_name` wired to local data.

    Writes a small test TFRecord file (10 samples, 224x224 images) into the
    test's temp dir via `_create_test_tfrecord`. Both the train and the
    validation inputs are pointed at it, so the test does not read the data
    paths baked into the experiment config.

    Args:
      config_name: Name of a registered experiment config.

    Returns:
      The experiment config, with a training global batch size of 2.
    """
    test_tfrecord_file = os.path.join(self.get_temp_dir(), 'cls_test.tfrecord')
    self._create_test_tfrecord(
        test_tfrecord_file=test_tfrecord_file,
        num_samples=10,
        input_image_size=[224, 224])
    config = exp_factory.get_exp_config(config_name)
    config.task.train_data.global_batch_size = 2
    config.task.validation_data.input_path = test_tfrecord_file
    config.task.train_data.input_path = test_tfrecord_file
    return config


def _build_training_setup(self, config):
    """Builds everything needed to run train/validation steps for `config`.

    The dataset is distributed with the current default distribution
    strategy.

    Args:
      config: Experiment config whose `task` and `trainer.optimizer_config`
        are used.

    Returns:
      A `(task, model, metrics, iterator, optimizer)` tuple, where `iterator`
      iterates over the distributed training dataset.
    """
    task = img_cls_task.ImageClassificationTask(config.task)
    model = task.build_model()
    metrics = task.build_metrics()

    strategy = tf.distribute.get_strategy()
    dataset = orbit.utils.make_distributed_dataset(
        strategy, task.build_inputs, config.task.train_data)
    iterator = iter(dataset)

    opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())

    # Give an ExponentialMovingAverage optimizer its shadow copy of the model
    # if it does not already have one.
    if (isinstance(optimizer, optimization.ExponentialMovingAverage) and
            not optimizer.has_shadow_copy):
        optimizer.shadow_copy(model)

    return task, model, metrics, iterator, optimizer