Code example #1
0
def test_zeroing_gradients(zero_grad):
    """
    Check the gradient-zeroing functionality (zero_grads_for_pruned_modules
    in the base pruning algorithm): when ``zero_grad`` is enabled, gradients
    of pruned filters must remain masked after a backward pass.

    :param zero_grad: whether gradient zeroing is enabled in the config
    """
    config = get_basic_pruning_config(input_sample_size=(2, 1, 8, 8))
    compression_params = config['compression']['params']
    compression_params['prune_first_conv'] = True
    compression_params['prune_last_conv'] = True
    compression_params['zero_grad'] = zero_grad

    model, algo, _ = create_pruning_algo_with_config(config)
    assert algo.zero_grad is zero_grad

    modules_to_check = [minfo.module for minfo in algo.pruned_module_info]

    device = next(model.parameters()).device
    loader = create_dataloader(config)

    algo.initialize(loader)

    optimizer, lr_scheduler = make_optimizer(
        get_parameter_groups(model, config), config)

    lr_scheduler.step(0)

    model.train()
    for batch, label in loader:
        batch = batch.to(device)
        label = label.to(device).view(1)

        prediction = model(batch)

        loss = torch.sum(label.to(torch.float32) - prediction)

        optimizer.zero_grad()
        loss.backward()

        # When zero_grad is enabled the gradients of pruned filters must be
        # masked, i.e. applying the pruning mask to the gradient is a no-op.
        if not zero_grad:
            continue
        for module in modules_to_check:
            pruning_op = list(module.pre_ops.values())[0]
            mask = pruning_op.operand.binary_filter_pruning_mask
            grad = module.weight.grad
            assert torch.allclose(apply_filter_binary_mask(mask, grad), grad)
Code example #2
0
def train(model, model_without_dp, compression_ctrl, train_loader, val_loader,
          criterion, class_encoding, config, resuming_checkpoint):
    """Train a semantic-segmentation model under NNCF compression.

    Runs the epoch loop, logs losses/mIoU to TensorBoard, validates every
    ``config.save_freq`` epochs, tracks the best mIoU per compression level
    and saves checkpoints.

    :param model: model to train (possibly wrapped for data-parallel execution)
    :param model_without_dp: unwrapped model used to build optimizer parameters
    :param compression_ctrl: NNCF compression controller (scheduler, loss, stats)
    :param train_loader: training data loader
    :param val_loader: validation data loader
    :param criterion: segmentation loss function
    :param class_encoding: ordered mapping of class name -> encoding; its order
        defines the class indices used by the IoU metric
    :param config: sample configuration (epochs, device, tb writer, save_freq, ...)
    :param resuming_checkpoint: checkpoint dict to resume from, or None
    :return: the trained model
    """
    logger.info("\nTraining...\n")

    # Check if the network architecture is correct
    logger.info(model)

    optim_config = config.get('optimizer', {})
    optim_params = optim_config.get('optimizer_params', {})
    lr = optim_params.get("lr", 1e-4)

    # NOTE(review): lr * 10 presumably boosts the learning rate for a subset
    # of parameters inside get_params_to_optimize — confirm there.
    params_to_optimize = get_params_to_optimize(model_without_dp, lr * 10,
                                                config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    # Evaluation metric

    ignore_index = None
    ignore_unlabeled = config.get("ignore_unlabeled", True)
    if ignore_unlabeled and ('unlabeled' in class_encoding):
        # Exclude the 'unlabeled' class from the IoU computation
        ignore_index = list(class_encoding).index('unlabeled')

    metric = IoU(len(class_encoding), ignore_index=ignore_index)

    best_miou = -1
    best_compression_level = CompressionLevel.NONE
    # Optionally resume from a checkpoint
    if resuming_checkpoint is not None:
        if optimizer is not None:
            optimizer.load_state_dict(resuming_checkpoint['optimizer'])
        start_epoch = resuming_checkpoint['epoch']
        best_miou = resuming_checkpoint['miou']

        # Restore the compression scheduler state too, if it was saved
        if "scheduler" in resuming_checkpoint and compression_ctrl.scheduler is not None:
            compression_ctrl.scheduler.load_state_dict(
                resuming_checkpoint['scheduler'])
        logger.info("Resuming from model: Start epoch = {0} "
                    "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
        config.start_epoch = start_epoch

    # Start Training
    train_obj = Train(model, train_loader, optimizer, criterion,
                      compression_ctrl, metric, config.device, config.model)
    val_obj = Test(model, val_loader, criterion, metric, config.device,
                   config.model)

    for epoch in range(config.start_epoch, config.epochs):
        # Advance the compression schedule once per epoch
        compression_ctrl.scheduler.epoch_step()
        logger.info(">>>> [Epoch: {0:d}] Training".format(epoch))

        if config.distributed:
            # Keep shard shuffling consistent across distributed workers
            train_loader.sampler.set_epoch(epoch)

        epoch_loss, (iou, miou) = train_obj.run_epoch(config.print_step)
        if not isinstance(lr_scheduler, ReduceLROnPlateau):
            # Learning rate scheduling should be applied after optimizer’s update
            lr_scheduler.step(epoch)

        logger.info(
            ">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
            format(epoch, epoch_loss, miou))

        if is_main_process():
            # Only the main process writes TensorBoard scalars
            config.tb.add_scalar("train/loss", epoch_loss, epoch)
            config.tb.add_scalar("train/mIoU", miou, epoch)
            config.tb.add_scalar("train/learning_rate",
                                 optimizer.param_groups[0]['lr'], epoch)
            config.tb.add_scalar("train/compression_loss",
                                 compression_ctrl.loss(), epoch)

            for key, value in compression_ctrl.statistics(
                    quickly_collected_only=True).items():
                # Only scalar statistics can be logged to TensorBoard
                if isinstance(value, (int, float)):
                    config.tb.add_scalar(
                        "compression/statistics/{0}".format(key), value, epoch)

        # Validate periodically and always on the last epoch
        if (epoch + 1) % config.save_freq == 0 or epoch + 1 == config.epochs:
            logger.info(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, (iou, miou) = val_obj.run_epoch(config.print_step)

            logger.info(
                ">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                format(epoch, loss, miou))

            if is_main_process():
                config.tb.add_scalar("val/mIoU", miou, epoch)
                config.tb.add_scalar("val/loss", loss, epoch)
                for i, (key,
                        class_iou) in enumerate(zip(class_encoding.keys(),
                                                    iou)):
                    config.tb.add_scalar(
                        "{}/mIoU_Cls{}_{}".format(config.dataset, i, key),
                        class_iou, epoch)

            # A checkpoint is "best" when it improves mIoU at the current
            # compression level, or reaches a higher compression level
            compression_level = compression_ctrl.compression_level()
            is_best_by_miou = miou > best_miou and compression_level == best_compression_level
            is_best = is_best_by_miou or compression_level > best_compression_level
            if is_best:
                best_miou = miou
            best_compression_level = max(compression_level,
                                         best_compression_level)

            if config.metrics_dump is not None:
                write_metrics(best_miou, config.metrics_dump)

            if isinstance(lr_scheduler, ReduceLROnPlateau):
                # Learning rate scheduling should be applied after optimizer’s update
                lr_scheduler.step(best_miou)

            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == config.epochs or is_best:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    logger.info("{0}: {1:.4f}".format(key, class_iou))

            # Save the model if it's the best thus far
            if is_main_process():
                checkpoint_path = save_checkpoint(model, optimizer, epoch,
                                                  best_miou, compression_level,
                                                  compression_ctrl.scheduler,
                                                  config)

                make_additional_checkpoints(checkpoint_path, is_best, epoch,
                                            config)
                print_statistics(compression_ctrl.statistics())

    return model
Code example #3
0
def train(model, model_without_dp, compression_algo, train_loader, val_loader, class_weights, class_encoding, config):
    """Train a semantic-segmentation model under compression (older API variant).

    Similar to the controller-based ``train`` but uses ``compression_algo``,
    plain ``print`` logging, a simple best-mIoU criterion, and derives the
    criterion from ``get_aux_loss_dependent_params``.

    :param model: model to train
    :param model_without_dp: unwrapped model used to build optimizer parameters
    :param compression_algo: compression algorithm object (scheduler, loss, stats)
    :param train_loader: training data loader
    :param val_loader: validation data loader
    :param class_weights: per-class loss weights
    :param class_encoding: ordered mapping of class name -> encoding; its order
        defines the class indices used by the IoU metric
    :param config: sample configuration object
    :return: the trained model
    """
    print("\nTraining...\n")

    # Check if the network architecture is correct
    print(model)

    optim_config = config.get('optimizer', {})
    optim_params = optim_config.get('optimizer_params', {})
    lr = optim_params.get("lr", 1e-4)

    # NOTE(review): lr * 10 presumably boosts the learning rate for a subset
    # of parameters — confirm inside get_aux_loss_dependent_params.
    params_to_optimize, criterion = get_aux_loss_dependent_params(model_without_dp,
                                                                  class_weights,
                                                                  lr * 10,
                                                                  config)

    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    # Evaluation metric

    ignore_index = None
    ignore_unlabeled = config.get("ignore_unlabeled", True)
    if ignore_unlabeled and ('unlabeled' in class_encoding):
        # Exclude the 'unlabeled' class from the IoU computation
        ignore_index = list(class_encoding).index('unlabeled')

    metric = IoU(len(class_encoding), ignore_index=ignore_index)

    best_miou = -1
    resuming_checkpoint = config.resuming_checkpoint
    # Optionally resume from a checkpoint
    if resuming_checkpoint is not None:
        model, optimizer, start_epoch, best_miou, _ = \
            load_checkpoint(
                model, resuming_checkpoint, config.device,
                optimizer, compression_algo.scheduler)
        print("Resuming from model: Start epoch = {0} "
              "| Best mean IoU = {1:.4f}".format(start_epoch, best_miou))
        config.start_epoch = start_epoch

    # Start Training
    train_obj = Train(model, train_loader, optimizer, criterion, compression_algo, metric, config.device,
                      config.model)
    val_obj = Test(model, val_loader, criterion, metric, config.device,
                   config.model)

    for epoch in range(config.start_epoch, config.epochs):
        print(">>>> [Epoch: {0:d}] Training".format(epoch))

        if config.distributed:
            # Keep shard shuffling consistent across distributed workers
            train_loader.sampler.set_epoch(epoch)

        # NOTE(review): here the LR scheduler steps before the epoch and the
        # compression scheduler after it — the opposite order of the other
        # train() variant; confirm which ordering is intended.
        if not isinstance(lr_scheduler, ReduceLROnPlateau):
            lr_scheduler.step(epoch)

        epoch_loss, (iou, miou) = train_obj.run_epoch(config.print_step)
        compression_algo.scheduler.epoch_step()

        print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
              format(epoch, epoch_loss, miou))

        if is_main_process():
            # Only the main process writes TensorBoard scalars
            config.tb.add_scalar("train/loss", epoch_loss, epoch)
            config.tb.add_scalar("train/mIoU", miou, epoch)
            config.tb.add_scalar("train/learning_rate", optimizer.param_groups[0]['lr'], epoch)
            config.tb.add_scalar("train/compression_loss", compression_algo.loss(), epoch)

            for key, value in compression_algo.statistics().items():
                # Only scalar statistics can be logged to TensorBoard
                if isinstance(value, (int, float)):
                    config.tb.add_scalar("compression/statistics/{0}".format(key), value, epoch)

        # Validate periodically and always on the last epoch
        if (epoch + 1) % config.save_freq == 0 or epoch + 1 == config.epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))

            loss, (iou, miou) = val_obj.run_epoch(config.print_step)

            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))

            if is_main_process():
                config.tb.add_scalar("val/mIoU", miou, epoch)
                config.tb.add_scalar("val/loss", loss, epoch)
                for i, (key, class_iou) in enumerate(zip(class_encoding.keys(), iou)):
                    config.tb.add_scalar("{}/mIoU_Cls{}_{}".format(config.dataset, i, key), class_iou, epoch)

            is_best = miou > best_miou
            best_miou = max(miou, best_miou)

            if isinstance(lr_scheduler, ReduceLROnPlateau):
                lr_scheduler.step(best_miou)

            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == config.epochs or is_best:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))

            # Save the model if it's the best thus far
            if is_main_process():
                checkpoint_path = save_checkpoint(model,
                                                  optimizer, epoch + 1, best_miou,
                                                  compression_algo.scheduler, config)

                make_additional_checkpoints(checkpoint_path, is_best, epoch + 1, config)
                print_statistics(compression_algo.statistics())

    return model
Code example #4
0
def main_worker(current_gpu, config):
    """Per-process entry point for classification training/evaluation.

    Sets up (possibly distributed) execution, builds the compressed model,
    criterion, optimizer and data loaders, then dispatches on ``config.mode``
    (test and/or train) or exports to ONNX.

    :param current_gpu: GPU index assigned to this worker process
    :param config: sample configuration object
    """
    config.current_gpu = current_gpu
    config.distributed = config.execution_mode in (
        ExecutionMode.DISTRIBUTED, ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    if config.distributed:
        configure_distributed(config)

    config.device = get_device(config)

    if is_main_process():
        configure_logging(config)
        print_args(config)

    if config.seed is not None:
        # Reproducibility: fixed seed and deterministic cuDNN (disables autotuning)
        manual_seed(config.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False

    # create model
    model_name = config['model']
    weights = config.get('weights')
    # Skip downloading pretrained weights when explicit weights were supplied
    model = load_model(model_name,
                       pretrained=config.get('pretrained', True)
                       if weights is None else False,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params'))
    compression_algo, model = create_compressed_model(model, config)
    if weights:
        load_state(model, torch.load(weights, map_location='cpu'))
    model, _ = prepare_model_for_execution(model, config)
    if config.distributed:
        compression_algo.distributed()

    # Inception models need special handling (aux logits) during training
    is_inception = 'inception' in model_name

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(config.device)

    params_to_optimize = get_parameter_groups(model, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    resuming_checkpoint = config.resuming_checkpoint
    best_acc1 = 0
    # optionally resume from a checkpoint
    if resuming_checkpoint is not None:
        model, config, optimizer, compression_algo, best_acc1 = \
            resume_from_checkpoint(resuming_checkpoint, model,
                                   config, optimizer, compression_algo)

    if config.to_onnx is not None:
        # Export-only mode: write the ONNX model and exit
        compression_algo.export_model(config.to_onnx)
        print("Saved to", config.to_onnx)
        return

    if config.execution_mode != ExecutionMode.CPU_ONLY:
        cudnn.benchmark = True

    # Data loading code
    train_loader, train_sampler, val_loader = create_dataloaders(config)

    if config.mode.lower() == 'test':
        print_statistics(compression_algo.statistics())
        validate(val_loader, model, criterion, config)

    if config.mode.lower() == 'train':
        if not resuming_checkpoint:
            # Data-driven initialization (presumably e.g. quantizer range
            # calibration) is only needed for fresh runs — resumed state
            # already contains it; confirm in compression_algo.initialize
            compression_algo.initialize(train_loader)
        train(config, compression_algo, model, criterion, is_inception,
              lr_scheduler, model_name, optimizer, train_loader, train_sampler,
              val_loader, best_acc1)
Code example #5
0
def main_worker(current_gpu, config):
    """Per-process entry point for SSD detection training/evaluation (older API).

    Sets up the environment, builds the compressed detection model, criterion
    and optimizer, optionally resumes from a checkpoint, then dispatches on
    ``config.mode`` (test/train) or exports to ONNX.

    :param current_gpu: GPU index assigned to this worker process
    :param config: sample configuration object
    """
    #################################
    # Setup experiment environment
    #################################
    config.current_gpu = current_gpu
    config.distributed = config.execution_mode in (ExecutionMode.DISTRIBUTED, ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    if config.distributed:
        configure_distributed(config)
    if is_on_first_rank(config):
        configure_logging(config)
        print_args(config)

    config.device = get_device(config)
    config.start_iter = 0

    ##########################
    # Prepare metrics log file
    ##########################

    # Pre-create the metrics record with a zero value so the dump exists
    # even if evaluation never completes
    if config.metrics_dump and config.resuming_checkpoint is not None:
        avg = 0
        metrics = {os.path.basename(config.resuming_checkpoint): avg}
        write_metrics(config, metrics)

    ##################
    # Prepare model
    ##################

    compression_algo, net = create_model(config)
    if config.distributed:
        # Scale per-process batch size and workers by the number of GPUs
        config.batch_size //= config.ngpus_per_node
        config.workers //= config.ngpus_per_node
        compression_algo.distributed()

    ###########################
    # Criterion and optimizer
    ###########################

    params_to_optimize = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    criterion = MultiBoxLoss(
        config,
        config['num_classes'],
        overlap_thresh=0.5,
        prior_for_matching=True,
        bkg_label=0,
        neg_mining=True,
        neg_pos=3,
        neg_overlap=0.5,
        encode_target=False,
        device=config.device
    )

    ###########################
    # Load checkpoint
    ###########################

    resuming_checkpoint = config.resuming_checkpoint
    if resuming_checkpoint:
        print('Resuming training, loading {}...'.format(resuming_checkpoint))
        checkpoint = torch.load(resuming_checkpoint, map_location='cpu')
        # use checkpoint itself in case of only state dict is saved
        # i.e. checkpoint is created with `torch.save(module.state_dict())`
        state_dict = checkpoint.get('state_dict', checkpoint)
        load_state(net, state_dict, is_resume=True)
        # Scheduler/optimizer/iteration state is only relevant when training
        if config.mode.lower() == 'train' and config.to_onnx is None:
            compression_algo.scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint.get('optimizer', optimizer.state_dict()))
            config.start_iter = checkpoint.get('iter', 0) + 1

    if config.to_onnx:
        # Export-only mode: write the ONNX model and exit
        compression_algo.export_model(config.to_onnx)
        print("Saved to {}".format(config.to_onnx))
        return

    ###########################
    # Prepare data
    ###########################

    test_data_loader, train_data_loader = create_dataloaders(config)

    if config.mode.lower() == 'test':
        with torch.no_grad():
            print_statistics(compression_algo.statistics())
            net.eval()
            mAp = test_net(net, config.device, test_data_loader, distributed=config.distributed)
            if config.metrics_dump and config.resuming_checkpoint is not None:
                # Record mAP as a percentage, rounded to two decimals
                avg = mAp*100
                metrics = {os.path.basename(config.resuming_checkpoint): round(avg, 2)}
                write_metrics(config, metrics)
            return

    if not resuming_checkpoint:
        # Data-driven initialization is only needed for fresh runs
        compression_algo.initialize(train_data_loader)

    train(net, compression_algo, train_data_loader, test_data_loader, criterion, optimizer, config, lr_scheduler)
Code example #6
0
def main_worker(current_gpu, config: SampleConfig):
    """Per-process entry point for classification training/evaluation (nncf_config API).

    Sets up the environment, builds data loaders (unless exporting to ONNX),
    creates the compressed model with an optional resuming state dict, restores
    optimizer/scheduler state when training, then dispatches on ``config.mode``.

    :param current_gpu: GPU index assigned to this worker process
    :param config: sample configuration object
    """
    config.current_gpu = current_gpu
    config.distributed = config.execution_mode in (
        ExecutionMode.DISTRIBUTED, ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    if config.distributed:
        configure_distributed(config)

    config.device = get_device(config)

    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    if config.seed is not None:
        # Reproducibility: fixed seed and deterministic cuDNN (disables autotuning)
        manual_seed(config.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(config.device)

    train_loader = train_sampler = val_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path
    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)

    if config.to_onnx is not None:
        # Export requires weights from somewhere: pretrained or a checkpoint
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        # Data loading code
        train_dataset, val_dataset = create_datasets(config)
        train_loader, train_sampler, val_loader = create_data_loaders(
            config, train_dataset, val_dataset)
        # Provide criterion/loader to NNCF for data-driven initialization
        nncf_config = register_default_init_args(nncf_config, criterion,
                                                 train_loader)

    # create model
    model_name = config['model']
    model = load_model(model_name,
                       pretrained=pretrained,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params'),
                       weights_path=config.get('weights'))

    model.to(config.device)

    resuming_model_sd = None
    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(
            resuming_checkpoint_path)
        resuming_model_sd = resuming_checkpoint['state_dict']

    # The model state dict is restored as part of compressed-model creation
    compression_ctrl, model = create_compressed_model(
        model, nncf_config, resuming_state_dict=resuming_model_sd)

    if config.to_onnx:
        # Export-only mode: write the ONNX model and exit
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    model, _ = prepare_model_for_execution(model, config)
    if config.distributed:
        compression_ctrl.distributed()

    # define optimizer
    params_to_optimize = get_parameter_groups(model, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    best_acc1 = 0
    # optionally resume from a checkpoint
    if resuming_checkpoint_path is not None:
        # Optimizer/scheduler/epoch state is only restored when training
        if config.mode.lower() == 'train' and config.to_onnx is None:
            config.start_epoch = resuming_checkpoint['epoch']
            best_acc1 = resuming_checkpoint['best_acc1']
            compression_ctrl.scheduler.load_state_dict(
                resuming_checkpoint['scheduler'])
            optimizer.load_state_dict(resuming_checkpoint['optimizer'])
            logger.info(
                "=> loaded checkpoint '{}' (epoch: {}, best_acc1: {:.3f})".
                format(resuming_checkpoint_path, resuming_checkpoint['epoch'],
                       best_acc1))
        else:
            logger.info(
                "=> loaded checkpoint '{}'".format(resuming_checkpoint_path))

    if config.execution_mode != ExecutionMode.CPU_ONLY:
        cudnn.benchmark = True

    if config.mode.lower() == 'test':
        print_statistics(compression_ctrl.statistics())
        validate(val_loader, model, criterion, config)

    if config.mode.lower() == 'train':
        # Inception models need special handling (aux logits) during training
        is_inception = 'inception' in model_name
        train(config, compression_ctrl, model, criterion, is_inception,
              lr_scheduler, model_name, optimizer, train_loader, train_sampler,
              val_loader, best_acc1)
Code example #7
0
File: main.py  Project: zmaslova/nncf_pytorch
def main_worker(current_gpu, config):
    """Per-process entry point for SSD detection training/evaluation (nncf_config API).

    Sets up the environment, builds the criterion and data loaders (unless
    exporting to ONNX), creates the compressed model with an optional resuming
    state dict, restores optimizer/scheduler state when training, then
    dispatches on ``config.mode``.

    :param current_gpu: GPU index assigned to this worker process
    :param config: sample configuration object
    """
    #################################
    # Setup experiment environment
    #################################
    config.current_gpu = current_gpu
    config.distributed = config.execution_mode in (
        ExecutionMode.DISTRIBUTED, ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    if config.distributed:
        configure_distributed(config)
    if is_on_first_rank(config):
        configure_logging(logger, config)
        print_args(config)

    config.device = get_device(config)
    config.start_iter = 0

    ##########################
    # Prepare metrics log file
    ##########################

    # Pre-create the metrics dump with a zero value so the file exists
    # even if evaluation never completes
    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    ###########################
    # Criterion
    ###########################

    criterion = MultiBoxLoss(config,
                             config['num_classes'],
                             overlap_thresh=0.5,
                             prior_for_matching=True,
                             bkg_label=0,
                             neg_mining=True,
                             neg_pos=3,
                             neg_overlap=0.5,
                             encode_target=False,
                             device=config.device)

    train_data_loader = test_data_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    ###########################
    # Prepare data
    ###########################

    pretrained = is_pretrained_model_requested(config)

    if config.to_onnx is not None:
        # Export requires weights from somewhere: pretrained or a checkpoint
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        test_data_loader, train_data_loader = create_dataloaders(config)
        # Provide criterion/loader to NNCF for data-driven initialization
        config.nncf_config = register_default_init_args(
            config.nncf_config, criterion, train_data_loader)

    ##################
    # Prepare model
    ##################
    resuming_checkpoint_path = config.resuming_checkpoint_path
    resuming_checkpoint = None
    resuming_model_state_dict = None

    if resuming_checkpoint_path:
        logger.info(
            'Resuming from checkpoint {}...'.format(resuming_checkpoint_path))
        resuming_checkpoint = torch.load(resuming_checkpoint_path,
                                         map_location='cpu')
        # use checkpoint itself in case only the state dict was saved,
        # i.e. the checkpoint was created with `torch.save(module.state_dict())`
        resuming_model_state_dict = resuming_checkpoint.get(
            'state_dict', resuming_checkpoint)

    compression_ctrl, net = create_model(config, resuming_model_state_dict)
    if config.distributed:
        # Scale per-process batch size and workers by the number of GPUs
        config.batch_size //= config.ngpus_per_node
        config.workers //= config.ngpus_per_node
        compression_ctrl.distributed()

    ###########################
    # Optimizer
    ###########################

    params_to_optimize = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    #################################
    # Load additional checkpoint data
    #################################

    # Scheduler/optimizer/iteration state is only relevant when training
    if resuming_checkpoint is not None and config.mode.lower(
    ) == 'train' and config.to_onnx is None:
        compression_ctrl.scheduler.load_state_dict(
            resuming_checkpoint['scheduler'])
        optimizer.load_state_dict(
            resuming_checkpoint.get('optimizer', optimizer.state_dict()))
        config.start_iter = resuming_checkpoint.get('iter', 0) + 1

    if config.to_onnx:
        # Export-only mode: write the ONNX model and exit
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if config.mode.lower() == 'test':
        with torch.no_grad():
            print_statistics(compression_ctrl.statistics())
            net.eval()
            mAp = test_net(net,
                           config.device,
                           test_data_loader,
                           distributed=config.distributed)
            if config.metrics_dump is not None:
                write_metrics(mAp, config.metrics_dump)
            return

    train(net, compression_ctrl, train_data_loader, test_data_loader,
          criterion, optimizer, config, lr_scheduler)
Code example #8
0
def main_worker(current_gpu, config):
    """Per-process entry point for SSD detection training/evaluation (mlflow/AutoQ API).

    Sets up the environment and MLflow logging, builds the criterion and data
    loaders with AutoQ evaluation hooks (unless exporting to ONNX), creates the
    compressed model with an optional resuming state dict, restores
    optimizer/scheduler state when training, then dispatches on ``config.mode``.

    :param current_gpu: GPU index assigned to this worker process
    :param config: sample configuration object
    """
    #################################
    # Setup experiment environment
    #################################
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    if is_on_first_rank(config):
        configure_logging(logger, config)
        print_args(config)

    config.start_iter = 0
    nncf_config = config.nncf_config
    ##########################
    # Prepare metrics log file
    ##########################

    # Pre-create the metrics dump with a zero value so the file exists
    # even if evaluation never completes
    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    ###########################
    # Criterion
    ###########################

    criterion = MultiBoxLoss(config,
                             config['num_classes'],
                             overlap_thresh=0.5,
                             prior_for_matching=True,
                             bkg_label=0,
                             neg_mining=True,
                             neg_pos=3,
                             neg_overlap=0.5,
                             encode_target=False,
                             device=config.device)

    train_data_loader = test_data_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    ###########################
    # Prepare data
    ###########################

    pretrained = is_pretrained_model_requested(config)

    if config.to_onnx is not None:
        # Export requires weights from somewhere: pretrained or a checkpoint
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        test_data_loader, train_data_loader, init_data_loader = create_dataloaders(
            config)

        def criterion_fn(model_outputs, target, criterion):
            # Combine localization and confidence losses into a single scalar
            loss_l, loss_c = criterion(model_outputs, target)
            return loss_l + loss_c

        def autoq_test_fn(model, eval_loader):
            # RL is maximization, change the loss polarity
            return -1 * test_net(model,
                                 config.device,
                                 eval_loader,
                                 distributed=config.distributed,
                                 loss_inference=True,
                                 criterion=criterion)

        # Provide loaders/criterion/eval function to NNCF for data-driven
        # initialization and AutoQ search
        nncf_config = register_default_init_args(nncf_config, init_data_loader,
                                                 criterion, criterion_fn,
                                                 autoq_test_fn,
                                                 test_data_loader,
                                                 config.device)

    ##################
    # Prepare model
    ##################
    # NOTE(review): resuming_checkpoint_path was already read above —
    # this re-assignment is redundant but harmless
    resuming_checkpoint_path = config.resuming_checkpoint_path

    resuming_model_sd = None
    if resuming_checkpoint_path is not None:
        resuming_model_sd, resuming_checkpoint = load_resuming_model_state_dict_and_checkpoint_from_path(
            resuming_checkpoint_path)

    compression_ctrl, net = create_model(config, resuming_model_sd)
    if config.distributed:
        # Scale per-process batch size and workers by the number of GPUs
        config.batch_size //= config.ngpus_per_node
        config.workers //= config.ngpus_per_node
        compression_ctrl.distributed()

    ###########################
    # Optimizer
    ###########################

    params_to_optimize = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    #################################
    # Load additional checkpoint data
    #################################

    # Scheduler/optimizer/iteration state is only relevant when training
    if resuming_checkpoint_path is not None and config.mode.lower(
    ) == 'train' and config.to_onnx is None:
        compression_ctrl.scheduler.load_state_dict(
            resuming_checkpoint['scheduler'])
        optimizer.load_state_dict(
            resuming_checkpoint.get('optimizer', optimizer.state_dict()))
        config.start_iter = resuming_checkpoint.get('iter', 0) + 1

    log_common_mlflow_params(config)

    if config.to_onnx:
        # Export-only mode: write the ONNX model and exit
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if is_main_process():
        print_statistics(compression_ctrl.statistics())

    if config.mode.lower() == 'test':
        with torch.no_grad():
            net.eval()
            if config['ssd_params'].get('loss_inference', False):
                # Report the model loss instead of mAP when requested
                model_loss = test_net(net,
                                      config.device,
                                      test_data_loader,
                                      distributed=config.distributed,
                                      loss_inference=True,
                                      criterion=criterion)
                logger.info("Final model loss: {:.3f}".format(model_loss))
            else:
                mAp = test_net(net,
                               config.device,
                               test_data_loader,
                               distributed=config.distributed)
                if config.metrics_dump is not None:
                    write_metrics(mAp, config.metrics_dump)
            return

    train(net, compression_ctrl, train_data_loader, test_data_loader,
          criterion, optimizer, config, lr_scheduler)