def worker(rank: int, world_size: int) -> None:
    torch.distributed.init_process_group(backend="nccl",
                                         init_method='tcp://127.0.0.1:8999',
                                         world_size=world_size,
                                         rank=rank)
    model = TestModelWithChangedTrain(freezing_stages=1)
    model.to(rank)  # place the model on this process's GPU

    nncf_config = NNCFConfig()
    nncf_config.update({
        "input_info": {
            "sample_size": [1, 1, 30, 30]
        },
        "compression": {
            "algorithm": "quantization",
            "initializer": {
                "range": {
                    "num_init_samples": 10
                },
                "batchnorm_adaptation": {
                    "num_bn_adaptation_samples": 10
                }
            }
        }
    })
    dataloader = create_random_mock_dataloader(nncf_config, num_samples=10)
    nncf_config = register_default_init_args(nncf_config, dataloader)

    _, compressed_model = create_compressed_model(model, nncf_config)

    # At this point the additional worker processes may hang

    _ = torch.nn.parallel.DistributedDataParallel(compressed_model,
                                                  device_ids=[rank])
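
For reference, a worker with this (rank, world_size) signature is normally launched via torch.multiprocessing.spawn, which supplies the rank automatically. A minimal launcher sketch (using all visible GPUs; the hard-coded init_method port must match the one in worker, and none of this launcher code is part of the original example):

import torch
import torch.multiprocessing as mp

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    # spawn() runs worker(rank, world_size) in world_size separate processes,
    # passing each process's rank as the implicit first argument
    mp.spawn(worker, args=(world_size, ), nprocs=world_size, join=True)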


def test_quantization_configuration_stats(data):
    config = get_basic_quantization_config()
    config['compression']['ignored_scopes'] = data.ignored_scopes
    config['input_info']['sample_size'] = [2, 3, 299, 299]

    ctrl, _ = create_compressed_model(
        test_models.Inception3(aux_logits=True, transform_input=True), config)
    stats = ShareEdgesQuantizedDataPathStatisticsCollector(ctrl.model,
                                                           ctrl).collect()

    for attr_name, expected_value in data.expected.items():
        actual_value = as_dict(getattr(stats, attr_name))
        assert expected_value == actual_value


def test_memory_consumption_stats(data):
    config = get_basic_quantization_config()
    config['compression']['initializer'].update(data.initializers)
    config['compression']['weights'] = data.weights
    config['compression']['ignored_scopes'] = data.ignored_scopes
    config['target_device'] = data.target_device

    ctrl, _ = create_compressed_model(test_models.AlexNet(), config)
    stats = MemoryConsumptionStatisticsCollector(
        ctrl.model, ctrl.weight_quantizers,
        ctrl.non_weight_quantizers).collect()

    for attr_name, expected_value in data.expected.items():
        actual_value = getattr(stats, attr_name)
        assert expected_value == pytest.approx(actual_value, rel=1e-2)


def test_quantization_share_and_bitwidth_distribution_stats(data):
    config = get_basic_quantization_config()
    config['compression']['initializer'].update(data.initializers)
    config['compression']['activations'] = data.activations
    config['compression']['weights'] = data.weights
    config['compression']['ignored_scopes'] = data.ignored_scopes
    config['target_device'] = data.target_device

    ctrl, _ = create_compressed_model(test_models.AlexNet(), config)
    nncf_stats = ctrl.statistics()
    quantization_stats = nncf_stats.quantization

    for attr_name, expected_value in data.expected.items():
        actual_value = as_dict(getattr(quantization_stats, attr_name))
        assert expected_value == actual_value
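
The three tests above share one pattern: a `data` case object carries both the config tweaks and the expected statistics, presumably supplied through `pytest.mark.parametrize`. A hypothetical sketch of such a case description -- the `StatsTestCaseDesc` name, its fields, and the placeholder values are illustrative, not the actual fixtures these tests use:

from dataclasses import dataclass, field
from typing import Any, Dict, List

import pytest

@dataclass
class StatsTestCaseDesc:
    ignored_scopes: List[str] = field(default_factory=list)
    expected: Dict[str, Any] = field(default_factory=dict)

STATS_TEST_CASES = [
    StatsTestCaseDesc(ignored_scopes=[],
                      expected={'some_stat_attr': 'placeholder value'}),
]

@pytest.mark.parametrize('data', STATS_TEST_CASES)
def test_quantization_configuration_stats(data):
    ...  # body as in the example above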
Example 5
def create_model(config: SampleConfig, resuming_checkpoint: dict = None):
    input_info_list = create_input_infos(config.nncf_config)
    image_size = input_info_list[0].shape[-1]
    ssd_net = build_ssd(config.model, config.ssd_params, image_size,
                        config.num_classes, config)
    weights = config.get('weights')
    if weights:
        sd = torch.load(weights,
                        map_location='cpu',
                        pickle_module=restricted_pickle_module)
        sd = sd["state_dict"]
        load_state(ssd_net, sd)

    ssd_net.to(config.device)
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint)
    compression_ctrl, compressed_model = create_compressed_model(
        ssd_net, config.nncf_config, compression_state)
    if model_state_dict is not None:
        load_state(compressed_model, model_state_dict, is_resume=True)
    compressed_model, _ = prepare_model_for_execution(compressed_model, config)

    compressed_model.train()
    return compression_ctrl, compressed_model
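
The resuming path above implies a checkpoint layout that extract_model_and_compression_states can split back into model weights and NNCF compression state. A sketch of producing such a checkpoint -- the 'state_dict' and 'compression_state' key names follow the common NNCF sample convention and should be treated as an assumption here:

import torch

def save_resumable_checkpoint(compression_ctrl, compressed_model, path):
    # get_compression_state() captures what NNCF needs to rebuild
    # the compressed model structure on resume
    checkpoint = {
        'state_dict': compressed_model.state_dict(),
        'compression_state': compression_ctrl.get_compression_state(),
    }
    torch.save(checkpoint, path)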
Example 6
def wrap_nncf_model(model,
                    cfg,
                    checkpoint_dict=None,
                    datamanager_for_init=None):
    # Note: these imports are done here to avoid cyclic imports when
    # get_no_nncf_trace_context_manager is imported from mobilenetv3
    from torchreid.data.transforms import build_inference_transform

    from nncf import NNCFConfig
    from nncf.torch import create_compressed_model, load_state
    from nncf.torch.initialization import register_default_init_args
    from nncf.torch.dynamic_graph.io_handling import nncf_model_input
    from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
    from nncf.torch.initialization import PTInitializingDataLoader

    if checkpoint_dict is None:
        checkpoint_path = cfg.model.load_weights
        resuming_checkpoint = safe_load_checkpoint(
            checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint_path = 'pretrained_dict'
        resuming_checkpoint = checkpoint_dict

    if datamanager_for_init is None and not is_nncf_state(resuming_checkpoint):
        raise RuntimeError('Either datamanager_for_init or NNCF pre-trained '
                           'model checkpoint should be set')

    nncf_metainfo = None
    if is_nncf_state(resuming_checkpoint):
        nncf_metainfo = _get_nncf_metainfo_from_state(resuming_checkpoint)
        nncf_config_data = nncf_metainfo['nncf_config']
        datamanager_for_init = None
        logger.info(f'Read NNCF metainfo with NNCF config from the checkpoint: '
                    f'nncf_metainfo=\n{pformat(nncf_metainfo)}')
    else:
        resuming_checkpoint = None
        nncf_config_data = cfg.get('nncf_config')

        if nncf_config_data is None:
            logger.info('Cannot read nncf_config from config file')
        else:
            logger.info(f' nncf_config=\n{pformat(nncf_config_data)}')

    h, w = cfg.data.height, cfg.data.width
    if not nncf_config_data:
        logger.info('Using the default NNCF int8 quantization config')
        nncf_config_data = get_default_nncf_compression_config(h, w)

    # do this even if nncf_config_data was loaded from a checkpoint -- to cover
    # the rare case when the width and height of the model's input were changed
    # in the config and fine-tuning of the NNCF model is then run
    nncf_config_data.setdefault('input_info', {})
    nncf_config_data['input_info']['sample_size'] = [1, 3, h, w]

    nncf_config = NNCFConfig(nncf_config_data)
    logger.info(f'nncf_config =\n{pformat(nncf_config)}')
    if not nncf_metainfo:
        nncf_metainfo = create_nncf_metainfo(enable_quantization=True,
                                             enable_pruning=False,
                                             nncf_config=nncf_config_data)
    else:
        # update it just to be on the safe side
        nncf_metainfo['nncf_config'] = nncf_config_data

    class ReidInitializeDataLoader(PTInitializingDataLoader):
        def get_inputs(self, dataloader_output):
            # define our own InitializingDataLoader using an approach similar to
            # parse_data_for_train and parse_data_for_eval in the Engine class;
            # dataloader_output[0] is expected to be the image tensor here
            args = (dataloader_output[0], )
            return args, {}

    @torch.no_grad()
    def model_eval_fn(model):
        """
        Runs evaluation of the model on the validation set and
        returns the target metric value.
        Used to evaluate the original model before compression
        if NNCF-based accuracy-aware training is used.
        """
        from torchreid.metrics.classification import (
            evaluate_classification, evaluate_multilabel_classification)

        if test_loader is None:
            raise RuntimeError(
                'Cannot perform a model evaluation on the validation '
                'dataset since the validation data loader was not passed '
                'to wrap_nncf_model')

        model_type = get_model_attr(model, 'type')
        targets = list(test_loader.keys())
        use_gpu = cur_device.type == 'cuda'
        for dataset_name in targets:
            domain = 'source' if dataset_name in datamanager_for_init.sources else 'target'
            print(f'##### Evaluating {dataset_name} ({domain}) #####')
            if model_type == 'classification':
                cmc, _, _ = evaluate_classification(
                    test_loader[dataset_name]['query'], model, use_gpu=use_gpu)
                accuracy = cmc[0]
            elif model_type == 'multilabel':
                mAP, _, _, _, _, _, _ = evaluate_multilabel_classification(
                    test_loader[dataset_name]['query'], model, use_gpu=use_gpu)
                accuracy = mAP
            else:
                raise ValueError(
                    f'Cannot perform a model evaluation on the validation dataset '
                    f'since the model has unsupported model_type {model_type or "None"}'
                )

        # NB: the metric of the last evaluated dataset is returned
        return accuracy

    cur_device = next(model.parameters()).device
    logger.info(f'NNCF: cur_device = {cur_device}')

    # bind the loader names even when resuming from a checkpoint, so that
    # model_eval_fn raises its intended RuntimeError instead of a NameError
    train_loader = test_loader = None

    if resuming_checkpoint is None:
        logger.info(
            'No NNCF checkpoint is provided -- register initialize data loader'
        )
        train_loader = datamanager_for_init.train_loader
        test_loader = datamanager_for_init.test_loader
        wrapped_loader = ReidInitializeDataLoader(train_loader)
        nncf_config = register_default_init_args(nncf_config,
                                                 wrapped_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=cur_device)
        model_state_dict = None
        compression_state = None
    else:
        model_state_dict, compression_state = extract_model_and_compression_states(
            resuming_checkpoint)

    transform = build_inference_transform(
        cfg.data.height,
        cfg.data.width,
        norm_mean=cfg.data.norm_mean,
        norm_std=cfg.data.norm_std,
    )

    def dummy_forward(model):
        prev_training_state = model.training
        model.eval()
        input_img = random_image(cfg.data.height, cfg.data.width)
        input_blob = transform(input_img).unsqueeze(0)
        assert len(input_blob.size()) == 4
        input_blob = input_blob.to(device=cur_device)
        input_blob = nncf_model_input(input_blob)
        model(input_blob)
        model.train(prev_training_state)

    def wrap_inputs(args, kwargs):
        assert len(args) == 1
        if isinstance(args[0], TracedTensor):
            logger.info('wrap_inputs: do not wrap input TracedTensor')
            return args, {}
        return (nncf_model_input(args[0]), ), kwargs

    model.dummy_forward_fn = dummy_forward
    if 'log_dir' in nncf_config:
        os.makedirs(nncf_config['log_dir'], exist_ok=True)
        logger.info(f'nncf_config["log_dir"] = {nncf_config["log_dir"]}')

    compression_ctrl, model = create_compressed_model(
        model,
        nncf_config,
        dummy_forward_fn=dummy_forward,
        wrap_inputs_fn=wrap_inputs,
        compression_state=compression_state)

    if model_state_dict:
        logger.info(f'Loading NNCF model from {checkpoint_path}')
        load_state(model, model_state_dict, is_resume=True)

    return compression_ctrl, model, nncf_metainfo
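
A typical call site for this wrapper, sketched under the assumption that a torchreid datamanager exposing train_loader and test_loader attributes has already been built:

# wrap the plain torchreid model for quantization-aware training
compression_ctrl, model, nncf_metainfo = wrap_nncf_model(
    model, cfg, datamanager_for_init=datamanager)
# the compression statistics can be inspected before training starts
logger.info(compression_ctrl.statistics().to_str())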
Example 7
def wrap_nncf_model(model,
                    cfg,
                    data_loader_for_init=None,
                    get_fake_input_func=None,
                    export=False):
    """
    The function wraps mmaction model by NNCF
    Note that the parameter `get_fake_input_func` should be the function `get_fake_input`
    -- cannot import this function here explicitly
    """

    check_nncf_is_enabled()

    from nncf.config import NNCFConfig
    from nncf.torch import (create_compressed_model,
                            register_default_init_args)
    from nncf.torch.dynamic_graph.io_handling import nncf_model_input
    from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
    from nncf.torch.initialization import DefaultInitializingDataLoader

    class MMInitializeDataLoader(DefaultInitializingDataLoader):
        def get_inputs(self, dataloader_output):
            # mmaction data loaders yield dicts, so pass the whole batch as kwargs
            return (), dataloader_output

    pathlib.Path(cfg.work_dir).mkdir(parents=True, exist_ok=True)
    nncf_config = NNCFConfig(cfg.nncf_config)
    logger = get_root_logger(cfg.log_level)

    if data_loader_for_init:
        wrapped_loader = MMInitializeDataLoader(data_loader_for_init)
        nncf_config = register_default_init_args(
            nncf_config,
            wrapped_loader,
            device=next(model.parameters()).device)

    if cfg.get('resume_from'):
        checkpoint_path = cfg.get('resume_from')
        assert is_checkpoint_nncf(checkpoint_path), (
            'It is possible to resume training with NNCF compression from NNCF checkpoints only. '
            'Use "load_from" with non-compressed model for further compression by NNCF.'
        )
    elif cfg.get('load_from'):
        checkpoint_path = cfg.get('load_from')
        if not is_checkpoint_nncf(checkpoint_path):
            checkpoint_path = None
            logger.info('Received non-NNCF checkpoint to start training '
                        '-- initialization of NNCF fields will be done')
    else:
        checkpoint_path = None

    if not data_loader_for_init and not checkpoint_path:
        raise RuntimeError('Either data_loader_for_init or NNCF pre-trained '
                           'model checkpoint should be set')

    if checkpoint_path:
        logger.info(f'Loading NNCF checkpoint from {checkpoint_path}')
        logger.info(
            'Please, note that this first loading is made before addition of '
            'NNCF FakeQuantize nodes to the model, so there may be some '
            'warnings on unexpected keys')
        resuming_state_dict = load_checkpoint(model, checkpoint_path)
        logger.info(f'Loaded NNCF checkpoint from {checkpoint_path}')
    else:
        resuming_state_dict = None

    if "nncf_compress_postprocessing" in cfg:
        # NB: This parameter chooses whether NNCF compression should be applied
        #     to the whole model graph including postprocessing
        #     (`nncf_compress_postprocessing=True`) or only to the part of the
        #     model without postprocessing (`nncf_compress_postprocessing=False`).
        #     The primary goal is to compress as large a part of the model as
        #     possible, so `nncf_compress_postprocessing=True` is the primary
        #     choice and `nncf_compress_postprocessing=False` is the fallback.
        #     Once NNCF compression works for sufficiently many models, only one
        #     choice should be kept.
        nncf_compress_postprocessing = cfg.get('nncf_compress_postprocessing')
        logger.debug('set should_compress_postprocessing='
                     f'{nncf_compress_postprocessing}')
    else:
        nncf_compress_postprocessing = True

    def _get_fake_data_for_forward(cfg, nncf_config, get_fake_input_func):
        input_size = nncf_config.get("input_info").get('sample_size')
        assert get_fake_input_func is not None
        assert len(input_size) == 4 and input_size[0] == 1
        H, W, C = input_size[2], input_size[3], input_size[1]
        device = next(model.parameters()).device
        with no_nncf_trace():
            return get_fake_input_func(cfg,
                                       orig_img_shape=(H, W, C),
                                       device=device)

    def dummy_forward(model):
        fake_data = _get_fake_data_for_forward(cfg, nncf_config,
                                               get_fake_input_func)
        img = fake_data["imgs"]
        img = nncf_model_input(img)
        if export:
            img, _, _ = model.reshape_input(imgs=img)
            model(imgs=img)
        else:
            model(imgs=img, return_loss=False)

    def wrap_inputs(args, kwargs):
        # during dummy_forward
        if not len(kwargs):
            if not isinstance(args[0][0], TracedTensor):
                args[0][0] = nncf_model_input(args[0][0])
            return args, kwargs

        # during building original graph
        if not kwargs.get('return_loss') and kwargs.get('forward_export'):
            return args, kwargs

        # during model's forward
        assert 'imgs' in kwargs, 'During model forward imgs must be in kwargs'
        img = kwargs['imgs']
        if isinstance(img, list):
            assert len(img) == 1, 'Input list must have length 1'
            assert torch.is_tensor(
                img[0]), 'Input for a model must be a tensor'
            if not isinstance(img[0], TracedTensor):
                img[0] = nncf_model_input(img[0])
        else:
            assert torch.is_tensor(img), 'Input for a model must be a tensor'
            if not isinstance(img, TracedTensor):
                img = nncf_model_input(img)
        kwargs['imgs'] = img
        return args, kwargs

    model.dummy_forward_fn = dummy_forward

    if 'log_dir' in nncf_config:
        os.makedirs(nncf_config['log_dir'], exist_ok=True)
    compression_ctrl, model = create_compressed_model(
        model,
        nncf_config,
        dummy_forward_fn=dummy_forward,
        wrap_inputs_fn=wrap_inputs,
        compression_state=resuming_state_dict)

    return compression_ctrl, model
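
As above, a sketch of the intended call, assuming get_fake_input is importable on the mmaction side as the docstring suggests, and that train_data_loader is an available training data loader (both names are placeholders here):

# data_loader_for_init may be omitted when resuming from an NNCF checkpoint
compression_ctrl, model = wrap_nncf_model(model,
                                          cfg,
                                          data_loader_for_init=train_data_loader,
                                          get_fake_input_func=get_fake_input)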
Example 8
def staged_quantization_main_worker(current_gpu, config):
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)

    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(config.device)

    model_name = config['model']
    is_inception = 'inception' in model_name
    train_criterion_fn = inception_criterion_fn if is_inception else default_criterion_fn

    train_loader = train_sampler = val_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path
    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)

    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        # Data loading code
        train_dataset, val_dataset = create_datasets(config)
        train_loader, train_sampler, val_loader, init_loader = create_data_loaders(
            config, train_dataset, val_dataset)

        def autoq_eval_fn(model, eval_loader):
            _, top5, _ = validate(eval_loader, model, criterion, config)
            return top5

        nncf_config = register_default_init_args(
            nncf_config,
            init_loader,
            criterion=criterion,
            criterion_fn=train_criterion_fn,
            autoq_eval_fn=autoq_eval_fn,
            val_loader=val_loader,
            device=config.device)

    # create model
    model_name = config['model']
    model = load_model(model_name,
                       pretrained=pretrained,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params'),
                       weights_path=config.get('weights'))
    original_model = copy.deepcopy(model)

    model.to(config.device)

    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(
            resuming_checkpoint_path)
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint)
    compression_ctrl, model = create_compressed_model(model, nncf_config,
                                                      compression_state)
    if model_state_dict is not None:
        load_state(model, model_state_dict, is_resume=True)

    if not isinstance(compression_ctrl,
                      (BinarizationController, QuantizationController)):
        raise RuntimeError(
            "The staged quantization sample worker may only be run with the binarization and quantization algorithms!"
        )

    model, _ = prepare_model_for_execution(model, config)
    original_model.to(config.device)

    if config.distributed:
        compression_ctrl.distributed()

    params_to_optimize = model.parameters()

    compression_config = config['compression']
    quantization_config = compression_config if isinstance(
        compression_config, dict) else compression_config[0]
    optimizer = get_quantization_optimizer(params_to_optimize,
                                           quantization_config)
    optimizer_scheduler = PolyLRDropScheduler(optimizer, quantization_config)
    kd_loss_calculator = KDLossCalculator(original_model)

    best_acc1 = 0
    # optionally resume from a checkpoint
    if resuming_checkpoint is not None and config.to_onnx is None:
        config.start_epoch = resuming_checkpoint['epoch']
        best_acc1 = resuming_checkpoint['best_acc1']
        kd_loss_calculator.original_model.load_state_dict(
            resuming_checkpoint['original_model_state_dict'])
        if 'train' in config.mode:
            optimizer.load_state_dict(resuming_checkpoint['optimizer'])
            optimizer_scheduler.load_state_dict(
                resuming_checkpoint['optimizer_scheduler'])
            logger.info(
                "=> loaded checkpoint '{}' (epoch: {}, best_acc1: {:.3f})".
                format(resuming_checkpoint_path, resuming_checkpoint['epoch'],
                       best_acc1))
        else:
            logger.info(
                "=> loaded checkpoint '{}'".format(resuming_checkpoint_path))

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if config.execution_mode != ExecutionMode.CPU_ONLY:
        cudnn.benchmark = True

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if 'train' in config.mode:
        batch_multiplier = (quantization_config.get("params", {})).get(
            "batch_multiplier", 1)
        train_staged(config, compression_ctrl, model, criterion,
                     train_criterion_fn, optimizer_scheduler, model_name,
                     optimizer, train_loader, train_sampler, val_loader,
                     kd_loss_calculator, batch_multiplier, best_acc1)

    if 'test' in config.mode:
        validate(val_loader, model, criterion, config)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
Example 9
def main_worker(current_gpu, config):
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)
    logger.info(config)

    dataset = get_dataset(config.dataset)
    color_encoding = dataset.color_encoding
    num_classes = len(color_encoding)

    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    train_loader = val_loader = criterion = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)

    def criterion_fn(model_outputs, target, criterion_):
        labels, loss_outputs, _ = \
            loss_funcs.do_model_specific_postprocessing(config.model, target, model_outputs)
        return criterion_(loss_outputs, labels)

    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)
    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        loaders, w_class = load_dataset(dataset, config)
        train_loader, val_loader, init_loader = loaders
        criterion = get_criterion(w_class, config)

        def autoq_test_fn(model, eval_loader):
            return test(model, eval_loader, criterion, color_encoding, config)

        model_eval_fn = functools.partial(autoq_test_fn,
                                          eval_loader=val_loader)

        nncf_config = register_default_init_args(nncf_config,
                                                 init_loader,
                                                 criterion=criterion,
                                                 criterion_fn=criterion_fn,
                                                 autoq_eval_fn=autoq_test_fn,
                                                 val_loader=val_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=config.device)

    model = load_model(config.model,
                       pretrained=pretrained,
                       num_classes=num_classes,
                       model_params=config.get('model_params', {}),
                       weights_path=config.get('weights'))

    model.to(config.device)

    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(
            resuming_checkpoint_path)
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint)
    compression_ctrl, model = create_compressed_model(model, nncf_config,
                                                      compression_state)
    if model_state_dict is not None:
        load_state(model, model_state_dict, is_resume=True)
    model, model_without_dp = prepare_model_for_execution(model, config)

    if config.distributed:
        compression_ctrl.distributed()

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if is_accuracy_aware_training(config) and 'train' in config.mode:

        def validate_fn(model, epoch):
            return test(model, val_loader, criterion, color_encoding, config)

        # training function that trains the model for one epoch (a full pass over
        # the training dataset); it is assumed that all the NNCF-related methods
        # are called properly inside this function (e.g. the step and epoch_step
        # methods of the compression scheduler)
        def train_epoch_fn(compression_ctrl, model, optimizer, **kwargs):
            ignore_index = None
            ignore_unlabeled = config.get("ignore_unlabeled", True)
            if ignore_unlabeled and ('unlabeled' in color_encoding):
                ignore_index = list(color_encoding).index('unlabeled')
            metric = IoU(len(color_encoding), ignore_index=ignore_index)
            train_obj = Train(model, train_loader, optimizer, criterion,
                              compression_ctrl, metric, config.device,
                              config.model)
            train_obj.run_epoch(config.print_step)

        # function that initializes optimizers & lr schedulers to start training
        def configure_optimizers_fn():
            optim_config = config.get('optimizer', {})
            optim_params = optim_config.get('optimizer_params', {})
            lr = optim_params.get("lr", 1e-4)
            params_to_optimize = get_params_to_optimize(
                model_without_dp, lr * 10, config)
            optimizer, lr_scheduler = make_optimizer(params_to_optimize,
                                                     config)
            return optimizer, lr_scheduler

        acc_aware_training_loop = create_accuracy_aware_training_loop(
            config, compression_ctrl)
        model = acc_aware_training_loop.run(
            model,
            train_epoch_fn=train_epoch_fn,
            validate_fn=validate_fn,
            configure_optimizers_fn=configure_optimizers_fn,
            tensorboard_writer=config.tb,
            log_dir=config.log_dir)

    elif 'train' in config.mode:
        train(model, model_without_dp, compression_ctrl, train_loader,
              val_loader, criterion, color_encoding, config,
              resuming_checkpoint)

    if 'test' in config.mode:
        logger.info(model)
        model_parameters = filter(lambda p: p.requires_grad,
                                  model.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        logger.info("Trainable argument count:{params}".format(params=params))
        model = model.to(config.device)
        test(model, val_loader, criterion, color_encoding, config)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))