Example #1
0
def create_initialized_compressed_model(model: nn.Module, config: NNCFConfig,
                                        train_loader: DataLoader) -> nn.Module:
    """Compress ``model`` with NNCF and run the default initialization.

    The input ``config`` is deep-copied so the caller's config object is not
    mutated by the registration of init args.

    :param model: Model to compress.
    :param config: NNCF configuration (copied, not modified in place).
    :param train_loader: Data loader used for compression initialization.
    :return: The compressed (NNCF-wrapped) model.
    """
    # Bug fix: pass a criterion *instance*, not the class. Initialization
    # algorithms invoke ``criterion(output, target)``; calling the nn.MSELoss
    # class with tensors would hit its constructor instead of computing a loss.
    config = register_default_init_args(deepcopy(config), train_loader,
                                        nn.MSELoss())
    model, _compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)
    return model
Example #2
0
def create_finetuned_lenet_model_and_dataloader(config,
                                                eval_fn,
                                                finetuning_steps,
                                                learning_rate=1e-3):
    """Build a LeNet, briefly finetune it on a mock dataloader under a fixed
    torch seed, then wrap it with NNCF compression.

    :param config: NNCF config; init args are registered into it in place.
    :param eval_fn: Evaluation callback; bound to the train loader via partial.
    :param finetuning_steps: Number of SGD steps to run before compression.
    :param learning_rate: SGD learning rate for the finetuning phase.
    :return: (compressed model, train loader, compression controller).
    """
    with set_torch_seed():
        train_loader = create_ones_mock_dataloader(config, num_samples=10)
        model = LeNet()
        # Start from small uniform weights so the short finetuning run
        # is reproducible.
        for parameter in model.parameters():
            nn.init.uniform_(parameter, a=0.0, b=0.01)

        optimizer = SGD(model.parameters(), lr=learning_rate)
        batches = iter(train_loader)
        for _step in range(finetuning_steps):
            optimizer.zero_grad()
            inputs, targets = next(batches)
            outputs = model(inputs)
            F.mse_loss(outputs.sum(), targets).backward()
            optimizer.step()

    config = register_default_init_args(
        config,
        train_loader=train_loader,
        model_eval_fn=partial(eval_fn, train_loader=train_loader))
    model, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)
    return model, train_loader, compression_ctrl
Example #3
0
def test_legr_class_setting_params(tmp_path):
    """Check that LeGR parameters from the NNCF config reach the LeGR object.

    Bug fix: the final checks previously *assigned* the reference values to
    the controller (``=``) instead of asserting them (``==``), so the test
    could never fail regardless of what the config produced.
    """
    generations_ref = 150
    train_steps_ref = 50
    max_pruning_ref = 0.1
    random_seed_ref = 1

    model = PruningTestModel()
    config = create_default_legr_config()
    config['compression']['params']['legr_params'] = {
        'generations': generations_ref,
        'train_steps': train_steps_ref,
        'max_pruning': max_pruning_ref,
        'random_seed': random_seed_ref,
    }

    train_loader = create_ones_mock_dataloader(config)
    val_loader = create_ones_mock_dataloader(config)
    train_steps_fn = lambda *x: None
    validate_fn = lambda *x: (0, 0)
    nncf_config = register_default_init_args(config, train_loader=train_loader, train_steps_fn=train_steps_fn,
                                             val_loader=val_loader, validate_fn=validate_fn)
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, nncf_config)

    # Assert (==), do not assign (=): this comparison is the point of the test.
    assert compression_ctrl.legr.num_generations == generations_ref
    assert compression_ctrl.legr.max_pruning == max_pruning_ref
    assert compression_ctrl.legr._train_steps == train_steps_ref
    assert compression_ctrl.legr.random_seed == random_seed_ref
Example #4
0
def test_mock_dump_checkpoint(aa_config):
    """Run the EarlyExit training loop with ``dump_checkpoints=True`` and
    verify that the user-supplied dump-checkpoint callback is invoked with
    arguments of the expected types."""
    dump_fn_was_called = False

    def mock_dump_checkpoint_fn(model, compression_controller,
                                accuracy_aware_runner, aa_log_dir):
        # Imported lazily, mirroring how the production code path does it.
        from nncf.api.compression import CompressionAlgorithmController
        from nncf.common.accuracy_aware_training.runner import TrainingRunner
        assert isinstance(model, torch.nn.Module)
        assert isinstance(compression_controller,
                          CompressionAlgorithmController)
        assert isinstance(accuracy_aware_runner, TrainingRunner)
        assert isinstance(aa_log_dir, str)
        nonlocal dump_fn_was_called
        dump_fn_was_called = True

    net = LeNet()
    # The mock loader is built from the accuracy-aware config fixture.
    train_loader = create_ones_mock_dataloader(aa_config, num_samples=10)
    cfg = get_quantization_config_without_range_init(LeNet.INPUT_SIZE[-1])
    cfg.update(aa_config)

    def train_fn(compression_ctrl,
                 model,
                 epoch,
                 optimizer,
                 lr_scheduler,
                 train_loader=train_loader):
        pass

    def mock_validate_fn(model, init_step=False, epoch=0):
        # Constant metric: the accuracy criterion is satisfied immediately.
        return 80

    def configure_optimizers_fn():
        opt = SGD(net.parameters(), lr=0.001)
        return opt, None

    cfg = register_default_init_args(cfg,
                                     train_loader=train_loader,
                                     model_eval_fn=partial(mock_validate_fn,
                                                           init_step=True))

    net, compression_ctrl = create_compressed_model_and_algo_for_test(
        net, cfg)

    training_loop = EarlyExitCompressionTrainingLoop(
        cfg, compression_ctrl, dump_checkpoints=True)
    net = training_loop.run(
        net,
        train_epoch_fn=train_fn,
        validate_fn=partial(mock_validate_fn),
        configure_optimizers_fn=configure_optimizers_fn,
        dump_checkpoint_fn=mock_dump_checkpoint_fn)
    assert dump_fn_was_called
Example #5
0
def test_legr_class_default_params(tmp_path):
    """Check the LeGR default hyperparameters on a freshly built controller.

    Bug fix: the checks previously *assigned* the expected defaults to the
    controller (``=``) instead of asserting them (``==``), so the test
    silently overwrote the values it was meant to verify.
    """
    model = PruningTestModel()
    config = create_default_legr_config()
    train_loader = create_ones_mock_dataloader(config)
    val_loader = create_ones_mock_dataloader(config)
    train_steps_fn = lambda *x: None
    validate_fn = lambda *x: (0, 0)
    nncf_config = register_default_init_args(config, train_loader=train_loader, train_steps_fn=train_steps_fn,
                                             val_loader=val_loader, validate_fn=validate_fn)
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, nncf_config)

    # Assert the documented defaults instead of assigning them.
    assert compression_ctrl.legr.num_generations == 400
    assert compression_ctrl.legr.max_pruning == 0.8
    assert compression_ctrl.legr._train_steps == 200
    assert compression_ctrl.legr.random_seed == 42
Example #6
0
def test_legr_reproducibility():
    """Two identical models compressed with the same LeGR config and the same
    global random seed must produce identical ranking coefficients."""
    np.random.seed(42)
    config = create_default_legr_config()

    train_loader = create_ones_mock_dataloader(config)
    val_loader = create_ones_mock_dataloader(config)
    train_steps_fn = lambda *x: None
    validate_fn = lambda *x: (0, np.random.random())
    nncf_config = register_default_init_args(config, train_loader=train_loader, train_steps_fn=train_steps_fn,
                                             val_loader=val_loader, validate_fn=validate_fn)
    model_1 = PruningTestModel()
    _, compression_ctrl_1 = create_compressed_model_and_algo_for_test(model_1, nncf_config)

    model_2 = PruningTestModel()
    # Bug fix: the second model was built from the raw `config` (without the
    # registered init args) instead of `nncf_config`; use the same registered
    # config for both so the comparison is meaningful.
    _, compression_ctrl_2 = create_compressed_model_and_algo_for_test(model_2, nncf_config)

    assert compression_ctrl_1.ranking_coeffs == compression_ctrl_2.ranking_coeffs
def test_accuracy_aware_config(aa_config, must_raise):
    """Verify that building an accuracy-aware training loop either succeeds or
    raises ``RuntimeError``, depending on the supplied accuracy-aware config
    fixture."""
    def mock_validate_fn(model):
        pass

    # Baseline accuracy-aware section; the fixture may override it.
    baseline_aa_section = {
        "accuracy_aware_training": {
            "mode": "adaptive_compression_level",
            "params": {
                "maximal_relative_accuracy_degradation": 1,
                "initial_training_phase_epochs": 1,
                "patience_epochs": 10
            }
        }
    }

    cfg = get_quantization_config_without_range_init(LeNet.INPUT_SIZE[-1])
    cfg.update(baseline_aa_section)
    cfg.update(aa_config)

    mock_loader = create_ones_mock_dataloader(cfg, num_samples=10)
    net = LeNet()

    cfg = register_default_init_args(cfg,
                                     train_loader=mock_loader,
                                     model_eval_fn=mock_validate_fn)
    net, compression_ctrl = create_compressed_model_and_algo_for_test(
        net, cfg)

    if not must_raise:
        _ = create_accuracy_aware_training_loop(cfg,
                                                compression_ctrl,
                                                dump_checkpoints=False)
    else:
        with pytest.raises(RuntimeError):
            _ = create_accuracy_aware_training_loop(cfg,
                                                    compression_ctrl,
                                                    dump_checkpoints=False)
Example #8
0
def test_early_exit_with_mock_validation(max_accuracy_degradation,
                                         exit_epoch_number,
                                         maximal_total_epochs=100):
    """Check that the EarlyExit training loop stops at the expected epoch.

    The mocked validation metric grows linearly with the epoch number and
    reaches the minimally acceptable value (baseline minus the allowed
    degradation) exactly at ``exit_epoch_number``, so the loop is expected
    to exit there.
    """
    # Updated by mock_validate_fn on every non-init validation call so the
    # final assert can observe at which epoch the loop stopped.
    epoch_counter = 0

    def mock_validate_fn(model, init_step=False, epoch=0):
        # Baseline metric reported for the original (uncompressed) model.
        original_metric = 0.85
        if init_step:
            # Initial evaluation performed before compression starts.
            return original_metric
        nonlocal epoch_counter
        epoch_counter = epoch
        # Scale the threshold metric by epoch/exit_epoch_number so it is first
        # reached exactly at epoch == exit_epoch_number. The relative branch
        # interprets the degradation as a percentage (hence the 0.01 factor).
        if "maximal_relative_accuracy_degradation" in max_accuracy_degradation:
            return original_metric * (1 - 0.01 * max_accuracy_degradation[
                'maximal_relative_accuracy_degradation']) * (epoch /
                                                             exit_epoch_number)
        return (original_metric - max_accuracy_degradation['maximal_absolute_accuracy_degradation']) * \
               epoch / exit_epoch_number

    config = get_quantization_config_without_range_init(LeNet.INPUT_SIZE[-1])

    # Build the early-exit accuracy-aware section from the test parameters.
    params = {"maximal_total_epochs": maximal_total_epochs}
    params.update(max_accuracy_degradation)
    accuracy_aware_config = {
        "accuracy_aware_training": {
            "mode": "early_exit",
            "params": params
        }
    }

    config.update(accuracy_aware_config)

    train_loader = create_ones_mock_dataloader(config, num_samples=10)
    model = LeNet()

    # The init-step evaluation of the original model is bound via partial.
    config = register_default_init_args(config,
                                        train_loader=train_loader,
                                        model_eval_fn=partial(mock_validate_fn,
                                                              init_step=True))

    model, compression_ctrl = create_compressed_model_and_algo_for_test(
        model, config)

    # Training is a no-op: only the mocked validation metric drives the exit.
    def train_fn(compression_ctrl,
                 model,
                 epoch,
                 optimizer,
                 lr_scheduler,
                 train_loader=train_loader):
        pass

    def configure_optimizers_fn():
        return None, None

    early_stopping_training_loop = EarlyExitCompressionTrainingLoop(
        config, compression_ctrl, dump_checkpoints=False)
    model = early_stopping_training_loop.run(
        model,
        train_epoch_fn=train_fn,
        validate_fn=partial(mock_validate_fn),
        configure_optimizers_fn=configure_optimizers_fn)
    # Epoch number starts from 0
    assert epoch_counter == exit_epoch_number
Example #9
0
def wrap_nncf_model(model,
                    cfg,
                    checkpoint_dict=None,
                    datamanager_for_init=None):
    """Wrap a torchreid model with NNCF compression (int8 quantization by default).

    The NNCF config is taken from an NNCF checkpoint when one is available
    (either ``checkpoint_dict`` or the file at ``cfg.model.load_weights``),
    otherwise from ``cfg.nncf_config`` or a built-in default. At least one of
    an NNCF pre-trained checkpoint or ``datamanager_for_init`` (supplying the
    init data loaders) must be provided.

    :return: tuple ``(compression_ctrl, compressed_model, nncf_metainfo)``.
    """
    # Note that we require to import it here to avoid cyclic imports when import get_no_nncf_trace_context_manager
    # from mobilenetv3
    from torchreid.data.transforms import build_inference_transform

    from nncf import NNCFConfig
    from nncf.torch import create_compressed_model, load_state
    from nncf.torch.initialization import register_default_init_args
    from nncf.torch.dynamic_graph.io_handling import nncf_model_input
    from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
    from nncf.torch.initialization import PTInitializingDataLoader

    # Resolve the checkpoint: an explicit dict wins over the config path.
    if checkpoint_dict is None:
        checkpoint_path = cfg.model.load_weights
        resuming_checkpoint = safe_load_checkpoint(
            checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint_path = 'pretrained_dict'
        resuming_checkpoint = checkpoint_dict

    if datamanager_for_init is None and not is_nncf_state(resuming_checkpoint):
        raise RuntimeError('Either datamanager_for_init or NNCF pre-trained '
                           'model checkpoint should be set')

    # Prefer the NNCF config embedded in an NNCF checkpoint; in that case the
    # init data manager is dropped, since initialization is checkpoint-driven.
    nncf_metainfo = None
    if is_nncf_state(resuming_checkpoint):
        nncf_metainfo = _get_nncf_metainfo_from_state(resuming_checkpoint)
        nncf_config_data = nncf_metainfo['nncf_config']
        datamanager_for_init = None
        logger.info(f'Read NNCF metainfo with NNCF config from the checkpoint:'
                    f'nncf_metainfo=\n{pformat(nncf_metainfo)}')
    else:
        # A non-NNCF checkpoint cannot be resumed as a compression state.
        resuming_checkpoint = None
        nncf_config_data = cfg.get('nncf_config')

        if nncf_config_data is None:
            logger.info('Cannot read nncf_config from config file')
        else:
            logger.info(f' nncf_config=\n{pformat(nncf_config_data)}')

    h, w = cfg.data.height, cfg.data.width
    if not nncf_config_data:
        logger.info('Using the default NNCF int8 quantization config')
        nncf_config_data = get_default_nncf_compression_config(h, w)

    # do it even if nncf_config_data is loaded from a checkpoint -- for the rare case when
    # the width and height of the model's input was changed in the config
    # and then finetuning of NNCF model is run
    nncf_config_data.setdefault('input_info', {})
    nncf_config_data['input_info']['sample_size'] = [1, 3, h, w]

    nncf_config = NNCFConfig(nncf_config_data)
    logger.info(f'nncf_config =\n{pformat(nncf_config)}')
    if not nncf_metainfo:
        nncf_metainfo = create_nncf_metainfo(enable_quantization=True,
                                             enable_pruning=False,
                                             nncf_config=nncf_config_data)
    else:
        # update it just to be on the safe side
        nncf_metainfo['nncf_config'] = nncf_config_data

    class ReidInitializeDataLoader(PTInitializingDataLoader):
        # Adapts torchreid batches to the (args, kwargs) form NNCF expects.
        def get_inputs(self, dataloader_output):
            # define own InitializingDataLoader class using approach like
            # parse_data_for_train and parse_data_for_eval in the class Engine
            # dataloader_output[0] should be image here
            args = (dataloader_output[0], )
            return args, {}

    @torch.no_grad()
    def model_eval_fn(model):
        """
        Runs evaluation of the model on the validation set and
        returns the target metric value.
        Used to evaluate the original model before compression
        if NNCF-based accuracy-aware training is used.
        """
        from torchreid.metrics.classification import evaluate_classification

        # NOTE(review): this closure reads `test_loader`, `cur_device` and
        # `datamanager_for_init` from the enclosing scope; `test_loader` is
        # only bound in the no-checkpoint branch below -- hence this guard.
        if test_loader is None:
            raise RuntimeError(
                'Cannot perform a model evaluation on the validation '
                'dataset since the validation data loader was not passed '
                'to wrap_nncf_model')

        model_type = get_model_attr(model, 'type')
        targets = list(test_loader.keys())
        use_gpu = cur_device.type == 'cuda'
        for dataset_name in targets:
            domain = 'source' if dataset_name in datamanager_for_init.sources else 'target'
            print(f'##### Evaluating {dataset_name} ({domain}) #####')
            if model_type == 'classification':
                # cmc[0] is the rank-1 accuracy.
                cmc, _, _ = evaluate_classification(
                    test_loader[dataset_name]['query'], model, use_gpu=use_gpu)
                accuracy = cmc[0]
            elif model_type == 'multilabel':
                mAP, _, _, _, _, _, _ = evaluate_multilabel_classification(
                    test_loader[dataset_name]['query'], model, use_gpu=use_gpu)
                accuracy = mAP
            else:
                raise ValueError(
                    f'Cannot perform a model evaluation on the validation dataset'
                    f'since the model has unsupported model_type {model_type or "None"}'
                )

        # NOTE(review): with several target datasets, only the metric of the
        # last one is returned -- confirm this is intended.
        return accuracy

    cur_device = next(model.parameters()).device
    logger.info(f'NNCF: cur_device = {cur_device}')

    if resuming_checkpoint is None:
        logger.info(
            'No NNCF checkpoint is provided -- register initialize data loader'
        )
        train_loader = datamanager_for_init.train_loader
        test_loader = datamanager_for_init.test_loader
        wrapped_loader = ReidInitializeDataLoader(train_loader)
        nncf_config = register_default_init_args(nncf_config,
                                                 wrapped_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=cur_device)
        model_state_dict = None
        compression_state = None
    else:
        # Resume path: restore model weights and compression state instead of
        # running NNCF initialization.
        model_state_dict, compression_state = extract_model_and_compression_states(
            resuming_checkpoint)

    transform = build_inference_transform(
        cfg.data.height,
        cfg.data.width,
        norm_mean=cfg.data.norm_mean,
        norm_std=cfg.data.norm_std,
    )

    def dummy_forward(model):
        # One forward pass on a random image so NNCF can trace the graph;
        # the previous train/eval mode is restored afterwards.
        prev_training_state = model.training
        model.eval()
        input_img = random_image(cfg.data.height, cfg.data.width)
        input_blob = transform(input_img).unsqueeze(0)
        assert len(input_blob.size()) == 4
        input_blob = input_blob.to(device=cur_device)
        input_blob = nncf_model_input(input_blob)
        model(input_blob)
        model.train(prev_training_state)

    def wrap_inputs(args, kwargs):
        # Mark the single image argument as an NNCF model input, unless it
        # has already been traced.
        assert len(args) == 1
        if isinstance(args[0], TracedTensor):
            logger.info('wrap_inputs: do not wrap input TracedTensor')
            return args, {}
        return (nncf_model_input(args[0]), ), kwargs

    model.dummy_forward_fn = dummy_forward
    if 'log_dir' in nncf_config:
        os.makedirs(nncf_config['log_dir'], exist_ok=True)
    # NOTE(review): the access below is unconditional and would raise KeyError
    # if 'log_dir' is absent from nncf_config -- confirm it is always set.
    logger.info(f'nncf_config["log_dir"] = {nncf_config["log_dir"]}')

    compression_ctrl, model = create_compressed_model(
        model,
        nncf_config,
        dummy_forward_fn=dummy_forward,
        wrap_inputs_fn=wrap_inputs,
        compression_state=compression_state)

    if model_state_dict:
        logger.info(f'Loading NNCF model from {checkpoint_path}')
        load_state(model, model_state_dict, is_resume=True)

    return compression_ctrl, model, nncf_metainfo
Example #10
0
def staged_quantization_main_worker(current_gpu, config):
    """Per-process worker for the staged quantization/binarization sample.

    Sets up device, logging and seed, builds the data loaders and NNCF init
    args (unless running export-only), wraps the model with a compression
    controller, optionally resumes from a checkpoint, and then runs training,
    testing and/or ONNX export depending on ``config.mode``.
    """
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)

    # Only the main process configures logging and prints the arguments.
    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(config.device)

    model_name = config['model']
    is_inception = 'inception' in model_name
    # Inception-style nets need a dedicated criterion fn (they produce
    # auxiliary outputs handled by inception_criterion_fn).
    train_criterion_fn = inception_criterion_fn if is_inception else default_criterion_fn

    train_loader = train_sampler = val_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path
    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)
    # Export-only mode skips dataset creation and init-args registration.
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)

    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        # Data loading code
        train_dataset, val_dataset = create_datasets(config)
        train_loader, train_sampler, val_loader, init_loader = create_data_loaders(
            config, train_dataset, val_dataset)

        def autoq_eval_fn(model, eval_loader):
            # AutoQ scores candidates by top-5 accuracy on the eval loader.
            _, top5, _ = validate(eval_loader, model, criterion, config)
            return top5

        nncf_config = register_default_init_args(
            nncf_config,
            init_loader,
            criterion=criterion,
            criterion_fn=train_criterion_fn,
            autoq_eval_fn=autoq_eval_fn,
            val_loader=val_loader,
            device=config.device)

    # create model
    model_name = config['model']
    model = load_model(model_name,
                       pretrained=pretrained,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params'),
                       weights_path=config.get('weights'))
    # Keep an untouched copy of the model for the KD loss calculator below.
    original_model = copy.deepcopy(model)

    model.to(config.device)

    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(
            resuming_checkpoint_path)
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint)
    compression_ctrl, model = create_compressed_model(model, nncf_config,
                                                      compression_state)
    if model_state_dict is not None:
        load_state(model, model_state_dict, is_resume=True)

    # This worker only supports the staged algorithms.
    if not isinstance(compression_ctrl,
                      (BinarizationController, QuantizationController)):
        raise RuntimeError(
            "The stage quantization sample worker may only be run with the binarization and quantization algorithms!"
        )

    model, _ = prepare_model_for_execution(model, config)
    original_model.to(config.device)

    if config.distributed:
        compression_ctrl.distributed()

    params_to_optimize = model.parameters()

    # 'compression' may be one algo dict or a list of them; the staged
    # quantization settings live in the single (or first) entry.
    compression_config = config['compression']
    quantization_config = compression_config if isinstance(
        compression_config, dict) else compression_config[0]
    optimizer = get_quantization_optimizer(params_to_optimize,
                                           quantization_config)
    optimizer_scheduler = PolyLRDropScheduler(optimizer, quantization_config)
    kd_loss_calculator = KDLossCalculator(original_model)

    best_acc1 = 0
    # optionally resume from a checkpoint
    if resuming_checkpoint is not None and config.to_onnx is None:
        config.start_epoch = resuming_checkpoint['epoch']
        best_acc1 = resuming_checkpoint['best_acc1']
        kd_loss_calculator.original_model.load_state_dict(
            resuming_checkpoint['original_model_state_dict'])
        if 'train' in config.mode:
            # Optimizer/scheduler state is only restored when training resumes.
            optimizer.load_state_dict(resuming_checkpoint['optimizer'])
            optimizer_scheduler.load_state_dict(
                resuming_checkpoint['optimizer_scheduler'])
            logger.info(
                "=> loaded checkpoint '{}' (epoch: {}, best_acc1: {:.3f})".
                format(resuming_checkpoint_path, resuming_checkpoint['epoch'],
                       best_acc1))
        else:
            logger.info(
                "=> loaded checkpoint '{}'".format(resuming_checkpoint_path))

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if config.execution_mode != ExecutionMode.CPU_ONLY:
        cudnn.benchmark = True

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if 'train' in config.mode:
        # NOTE(review): batch_multiplier presumably emulates a larger batch
        # via gradient accumulation inside train_staged -- confirm there.
        batch_multiplier = (quantization_config.get("params", {})).get(
            "batch_multiplier", 1)
        train_staged(config, compression_ctrl, model, criterion,
                     train_criterion_fn, optimizer_scheduler, model_name,
                     optimizer, train_loader, train_sampler, val_loader,
                     kd_loss_calculator, batch_multiplier, best_acc1)

    if 'test' in config.mode:
        validate(val_loader, model, criterion, config)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
Example #11
0
def main_worker(current_gpu, config):
    """Per-process worker for the SSD object-detection sample.

    Prepares the environment, data loaders and the MultiBox criterion, wraps
    the model with NNCF compression, then runs (accuracy-aware) training,
    testing and/or ONNX export depending on ``config.mode``.
    """
    #################################
    # Setup experiment environment
    #################################
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    if is_on_first_rank(config):
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)

    config.start_iter = 0
    nncf_config = config.nncf_config
    ##########################
    # Prepare metrics log file
    ##########################

    # Write a placeholder value so the metrics dump file always exists.
    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    ###########################
    # Criterion
    ###########################

    criterion = MultiBoxLoss(config,
                             config['num_classes'],
                             overlap_thresh=0.5,
                             prior_for_matching=True,
                             bkg_label=0,
                             neg_mining=True,
                             neg_pos=3,
                             neg_overlap=0.5,
                             encode_target=False,
                             device=config.device)

    train_data_loader = test_data_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    ###########################
    # Prepare data
    ###########################

    pretrained = is_pretrained_model_requested(config)

    # Export-only mode skips dataset creation and init-args registration.
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)
    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        test_data_loader, train_data_loader, init_data_loader = create_dataloaders(
            config)

        def criterion_fn(model_outputs, target, criterion):
            # MultiBoxLoss returns localization and confidence components.
            loss_l, loss_c = criterion(model_outputs, target)
            return loss_l + loss_c

        def autoq_test_fn(model, eval_loader):
            # RL is maximization, change the loss polarity
            return -1 * test_net(model,
                                 config.device,
                                 eval_loader,
                                 distributed=config.distributed,
                                 loss_inference=True,
                                 criterion=criterion)

        def model_eval_fn(model):
            model.eval()
            mAP = test_net(model,
                           config.device,
                           test_data_loader,
                           distributed=config.distributed,
                           criterion=criterion)
            return mAP

        nncf_config = register_default_init_args(nncf_config,
                                                 init_data_loader,
                                                 criterion=criterion,
                                                 criterion_fn=criterion_fn,
                                                 autoq_eval_fn=autoq_test_fn,
                                                 val_loader=test_data_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=config.device)

    ##################
    # Prepare model
    ##################
    resuming_checkpoint_path = config.resuming_checkpoint_path

    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(
            resuming_checkpoint_path)
    compression_ctrl, net = create_model(config, resuming_checkpoint)
    if config.distributed:
        # Split the global batch and workers across the processes on a node.
        config.batch_size //= config.ngpus_per_node
        config.workers //= config.ngpus_per_node
        compression_ctrl.distributed()

    ###########################
    # Optimizer
    ###########################

    params_to_optimize = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    #################################
    # Load additional checkpoint data
    #################################

    if resuming_checkpoint_path is not None and 'train' in config.mode:
        # Fall back to the fresh optimizer state if the checkpoint lacks one.
        optimizer.load_state_dict(
            resuming_checkpoint.get('optimizer', optimizer.state_dict()))
        config.start_epoch = resuming_checkpoint.get('epoch', 0) + 1

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if 'train' in config.mode and is_accuracy_aware_training(config):
        # validation function that returns the target metric value
        # pylint: disable=E1123
        def validate_fn(model, epoch):
            model.eval()
            mAP = test_net(model,
                           config.device,
                           test_data_loader,
                           distributed=config.distributed)
            model.train()
            return mAP

        # training function that trains the model for one epoch (full training dataset pass)
        # it is assumed that all the NNCF-related methods are properly called inside of
        # this function (like e.g. the step and epoch_step methods of the compression scheduler)
        def train_epoch_fn(compression_ctrl, model, epoch, optimizer,
                           **kwargs):
            loc_loss = 0
            conf_loss = 0
            epoch_size = len(train_data_loader)
            train_epoch(compression_ctrl, model, config, train_data_loader,
                        criterion, optimizer, epoch_size, epoch, loc_loss,
                        conf_loss)

        # function that initializes optimizers & lr schedulers to start training
        def configure_optimizers_fn():
            params_to_optimize = get_parameter_groups(net, config)
            optimizer, lr_scheduler = make_optimizer(params_to_optimize,
                                                     config)
            return optimizer, lr_scheduler

        acc_aware_training_loop = create_accuracy_aware_training_loop(
            nncf_config, compression_ctrl)
        net = acc_aware_training_loop.run(
            net,
            train_epoch_fn=train_epoch_fn,
            validate_fn=validate_fn,
            configure_optimizers_fn=configure_optimizers_fn,
            tensorboard_writer=config.tb,
            log_dir=config.log_dir)
    elif 'train' in config.mode:
        train(net, compression_ctrl, train_data_loader, test_data_loader,
              criterion, optimizer, config, lr_scheduler)

    if 'test' in config.mode:
        with torch.no_grad():
            net.eval()
            # Optionally report the model loss instead of mAP.
            if config['ssd_params'].get('loss_inference', False):
                model_loss = test_net(net,
                                      config.device,
                                      test_data_loader,
                                      distributed=config.distributed,
                                      loss_inference=True,
                                      criterion=criterion)
                logger.info("Final model loss: {:.3f}".format(model_loss))
            else:
                mAp = test_net(net,
                               config.device,
                               test_data_loader,
                               distributed=config.distributed)
                if config.metrics_dump is not None:
                    write_metrics(mAp, config.metrics_dump)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
Example #12
0
def main_worker(current_gpu, config):
    """Per-process worker for semantic-segmentation train/test/export with NNCF.

    Drives the full pipeline for one (possibly distributed) worker:
    device/logging setup, dataset + criterion construction, NNCF compressed
    model creation (optionally resuming from a checkpoint), then one or more
    of: accuracy-aware or plain training, evaluation, and ONNX export,
    depending on which tokens appear in ``config.mode``.

    :param current_gpu: GPU id for this worker (passed to ``configure_device``;
        exact type/semantics depend on that helper — presumably an int or None).
    :param config: sample config object; attribute access (``config.mode``,
        ``config.device``, ``config.nncf_config``, ...) is used throughout.
    """
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    # Only the main process writes logs / prints args in distributed runs.
    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)
    logger.info(config)

    dataset = get_dataset(config.dataset)
    color_encoding = dataset.color_encoding
    # Number of segmentation classes is derived from the dataset's color map.
    num_classes = len(color_encoding)

    if config.metrics_dump is not None:
        # Pre-seed the metrics file with 0 so it exists even if evaluation
        # never runs or fails early — presumably consumed by CI; verify.
        write_metrics(0, config.metrics_dump)

    train_loader = val_loader = criterion = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)

    def criterion_fn(model_outputs, target, criterion_):
        # Adapts raw model outputs to the criterion's expected inputs via
        # model-specific postprocessing (e.g. auxiliary-output handling).
        labels, loss_outputs, _ = \
            loss_funcs.do_model_specific_postprocessing(config.model, target, model_outputs)
        return criterion_(loss_outputs, labels)

    # Export-only mode skips dataset loading entirely; it then requires
    # weights from either a pretrained model or a resuming checkpoint.
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)
    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        loaders, w_class = load_dataset(dataset, config)
        train_loader, val_loader, init_loader = loaders
        criterion = get_criterion(w_class, config)

        def autoq_test_fn(model, eval_loader):
            # Evaluation callback used by NNCF AutoQ initialization.
            return test(model, eval_loader, criterion, color_encoding, config)

        model_eval_fn = functools.partial(autoq_test_fn,
                                          eval_loader=val_loader)

        # Attach everything NNCF initialization may need (init data loader,
        # loss, AutoQ / accuracy-aware evaluation callbacks) to the config.
        nncf_config = register_default_init_args(nncf_config,
                                                 init_loader,
                                                 criterion=criterion,
                                                 criterion_fn=criterion_fn,
                                                 autoq_eval_fn=autoq_test_fn,
                                                 val_loader=val_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=config.device)

    model = load_model(config.model,
                       pretrained=pretrained,
                       num_classes=num_classes,
                       model_params=config.get('model_params', {}),
                       weights_path=config.get('weights'))

    model.to(config.device)

    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(
            resuming_checkpoint_path)
    # Split the checkpoint into the model weights and the NNCF compression
    # state; the compressed model must be built BEFORE loading the weights,
    # since compression wraps/renames modules.
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint)
    compression_ctrl, model = create_compressed_model(model, nncf_config,
                                                      compression_state)
    if model_state_dict is not None:
        load_state(model, model_state_dict, is_resume=True)
    # model_without_dp keeps a handle to the unwrapped model (no DataParallel)
    # for optimizer parameter-group construction below.
    model, model_without_dp = prepare_model_for_execution(model, config)

    if config.distributed:
        # Let the compression controller sync its state across workers.
        compression_ctrl.distributed()

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if is_accuracy_aware_training(config) and 'train' in config.mode:

        def validate_fn(model, epoch):
            # `epoch` is required by the accuracy-aware loop's callback
            # signature but unused here.
            return test(model, val_loader, criterion, color_encoding, config)

        # training function that trains the model for one epoch (full training dataset pass)
        # it is assumed that all the NNCF-related methods are properly called inside of
        # this function (like e.g. the step and epoch_step methods of the compression scheduler)
        def train_epoch_fn(compression_ctrl, model, optimizer, **kwargs):
            ignore_index = None
            ignore_unlabeled = config.get("ignore_unlabeled", True)
            if ignore_unlabeled and ('unlabeled' in color_encoding):
                # Exclude the 'unlabeled' class from the IoU metric.
                ignore_index = list(color_encoding).index('unlabeled')
            metric = IoU(len(color_encoding), ignore_index=ignore_index)
            train_obj = Train(model, train_loader, optimizer, criterion,
                              compression_ctrl, metric, config.device,
                              config.model)
            train_obj.run_epoch(config.print_step)

        # function that initializes optimizers & lr schedulers to start training
        def configure_optimizers_fn():
            optim_config = config.get('optimizer', {})
            optim_params = optim_config.get('optimizer_params', {})
            lr = optim_params.get("lr", 1e-4)
            # lr * 10 is forwarded to get_params_to_optimize — presumably a
            # boosted learning rate for a subset of parameters (e.g. decoder);
            # confirm against that helper.
            params_to_optimize = get_params_to_optimize(
                model_without_dp, lr * 10, config)
            optimizer, lr_scheduler = make_optimizer(params_to_optimize,
                                                     config)
            return optimizer, lr_scheduler

        acc_aware_training_loop = create_accuracy_aware_training_loop(
            config, compression_ctrl)
        model = acc_aware_training_loop.run(
            model,
            train_epoch_fn=train_epoch_fn,
            validate_fn=validate_fn,
            configure_optimizers_fn=configure_optimizers_fn,
            tensorboard_writer=config.tb,
            log_dir=config.log_dir)

    elif 'train' in config.mode:
        train(model, model_without_dp, compression_ctrl, train_loader,
              val_loader, criterion, color_encoding, config,
              resuming_checkpoint)

    if 'test' in config.mode:
        logger.info(model)
        # Count trainable parameters for logging.
        model_parameters = filter(lambda p: p.requires_grad,
                                  model.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        # NOTE(review): message says "argument count" but the value is the
        # trainable *parameter* count.
        logger.info("Trainable argument count:{params}".format(params=params))
        model = model.to(config.device)
        test(model, val_loader, criterion, color_encoding, config)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))