def train_bin(config, compression_ctrl, model, criterion, is_inception, optimizer_scheduler, model_name, optimizer,
              train_loader, train_sampler, val_loader, kd_loss_calculator, batch_multiplier, best_acc1=0):
    """Drive the binarization training loop, one iteration per epoch.

    For every epoch in ``range(config.start_epoch, config.epochs)`` the loop trains via
    ``train_epoch_bin``, collects the compression statistics, validates whenever
    ``epoch % config.test_every_n_epochs == 0``, steps the compression and optimizer
    schedulers and, on the main process only, prints the statistics, saves a
    ``*_last.pth`` checkpoint and writes the numeric statistics to ``config.tb``.

    ``best_acc1`` is the best top-1 accuracy known before the loop starts.
    """
    for epoch in range(config.start_epoch, config.epochs):
        config.cur_epoch = epoch
        if config.distributed:
            # in distributed mode the sampler is told the epoch number before each pass
            train_sampler.set_epoch(epoch)

        # one full training pass over the data
        train_epoch_bin(train_loader, batch_multiplier, model, criterion, optimizer, optimizer_scheduler,
                        kd_loss_calculator, compression_ctrl, epoch, config, is_inception)

        # statistics are captured before the schedulers advance below
        compression_stats = compression_ctrl.statistics()

        if epoch % config.test_every_n_epochs == 0:
            epoch_acc1, _ = validate(val_loader, model, criterion, config)
        else:
            # no evaluation this epoch: reuse the running best so nothing is flagged as best
            epoch_acc1 = best_acc1

        # track the best acc@1 seen so far
        improved = epoch_acc1 > best_acc1
        best_acc1 = max(epoch_acc1, best_acc1)

        # end-of-epoch scheduler updates
        compression_ctrl.scheduler.epoch_step()
        optimizer_scheduler.epoch_step()

        # printing, checkpointing and tensorboard logging happen on the main process only
        if not is_main_process():
            continue

        print_statistics(compression_stats)

        last_checkpoint_path = osp.join(config.checkpoint_save_dir, get_name(config) + '_last.pth')
        snapshot = {
            'epoch': epoch + 1,
            'arch': model_name,
            'state_dict': model.state_dict(),
            'original_model_state_dict': kd_loss_calculator.original_model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
            'compression_scheduler': compression_ctrl.scheduler.state_dict(),
            'optimizer_scheduler': optimizer_scheduler.state_dict()
        }
        torch.save(snapshot, last_checkpoint_path)
        make_additional_checkpoints(last_checkpoint_path, improved, epoch + 1, config)

        for stat_name, stat_value in compression_stats.items():
            if not isinstance(stat_value, (int, float)):
                continue
            config.tb.add_scalar("compression/statistics/{0}".format(stat_name), stat_value,
                                 len(train_loader) * epoch)
def main_worker_binarization(current_gpu, config):
    """Set up one binarization worker, then export to ONNX, validate or train.

    The worker records ``current_gpu`` on ``config``, configures distributed execution
    and logging, builds the model together with an uncompressed deep copy (handed to
    ``KDLossCalculator``), wraps the model with the compression algorithm and creates
    the optimizer and its scheduler. When ``config.to_onnx`` is set the model is
    exported and the function returns early; otherwise the data loaders are created
    and ``config.mode`` selects validation (``'test'``) and/or training (``'train'``).

    Raises:
        RuntimeError: if the created compression controller is not a
            ``BinarizationController``.
    """
    config.current_gpu = current_gpu
    distributed_modes = (ExecutionMode.DISTRIBUTED, ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    config.distributed = config.execution_mode in distributed_modes
    if config.distributed:
        configure_distributed(config)

    config.device = get_device(config)

    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    seed = config.seed
    if seed is not None:
        manual_seed(seed)
        cudnn.deterministic = True
        cudnn.benchmark = False

    # build the model; explicit weights take precedence over the pretrained flag
    model_name = config['model']
    weights_path = config.get('weights')
    if weights_path is None:
        use_pretrained = config.get('pretrained', True)
    else:
        use_pretrained = False
    model = load_model(model_name,
                       pretrained=use_pretrained,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params'))

    # untouched copy of the model, taken before compression is applied
    reference_model = copy.deepcopy(model)
    compression_ctrl, model = create_compressed_model(model, config)
    if not isinstance(compression_ctrl, BinarizationController):
        raise RuntimeError("The binarization sample worker may only be run with the binarization algorithm!")

    if weights_path:
        load_state(model, torch.load(weights_path, map_location='cpu'))

    model, _ = prepare_model_for_execution(model, config)
    reference_model.to(config.device)

    if config.distributed:
        compression_ctrl.distributed()

    is_inception = 'inception' in model_name

    # loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(config.device)

    # optimizer and its scheduler, configured from the (first) compression section
    trainable_params = model.parameters()
    compression_section = config['compression']
    if isinstance(compression_section, dict):
        binarization_config = compression_section
    else:
        binarization_config = compression_section[0]
    optimizer = get_binarization_optimizer(trainable_params, binarization_config)
    optimizer_scheduler = BinarizationOptimizerScheduler(optimizer, binarization_config)
    kd_loss_calculator = KDLossCalculator(reference_model)

    # optionally resume from a checkpoint
    resuming_checkpoint = config.resuming_checkpoint
    best_acc1 = 0
    if resuming_checkpoint is not None:
        restored = resume_from_checkpoint(resuming_checkpoint, model,
                                          config, optimizer, optimizer_scheduler, kd_loss_calculator,
                                          compression_ctrl)
        model, config, optimizer, optimizer_scheduler, kd_loss_calculator, compression_ctrl, best_acc1 = restored

    # export-only run: write the ONNX file and stop
    if config.to_onnx is not None:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if config.execution_mode != ExecutionMode.CPU_ONLY:
        cudnn.benchmark = True

    # data loading
    train_dataset, val_dataset = create_datasets(config)
    train_loader, train_sampler, val_loader = create_data_loaders(config, train_dataset, val_dataset)

    if config.mode.lower() == 'test':
        print_statistics(compression_ctrl.statistics())
        validate(val_loader, model, criterion, config)

    if config.mode.lower() == 'train':
        # the compression algorithm is initialized only for a fresh (non-resumed) run
        if not resuming_checkpoint:
            compression_ctrl.initialize(data_loader=train_loader, criterion=criterion)

        algo_params = binarization_config.get("params", {})
        batch_multiplier = algo_params.get("batch_multiplier", 1)
        train_bin(config, compression_ctrl, model, criterion, is_inception, optimizer_scheduler, model_name, optimizer,
                  train_loader, train_sampler, val_loader, kd_loss_calculator, batch_multiplier, best_acc1)
def staged_quantization_main_worker(current_gpu, config):
    """Set up one staged-quantization worker, then export to ONNX, validate or train.

    The worker records ``current_gpu`` on ``config``, configures distributed execution
    and logging, and creates the loss criterion. Data loaders are built (and registered
    as default init args on the NNCF config) only when no ONNX export is requested.
    The model is then loaded, deep-copied for ``KDLossCalculator`` and wrapped with the
    compression algorithm, optionally starting from the ``state_dict`` of a resuming
    checkpoint. Depending on ``config.to_onnx`` and ``config.mode`` the function
    exports the model, validates it and/or runs ``train_staged``.

    Raises:
        RuntimeError: if the compression controller is neither a
            ``BinarizationController`` nor a ``QuantizationController``.
    """
    config.current_gpu = current_gpu
    distributed_modes = (ExecutionMode.DISTRIBUTED, ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    config.distributed = config.execution_mode in distributed_modes
    if config.distributed:
        configure_distributed(config)

    config.device = get_device(config)

    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    seed = config.seed
    if seed is not None:
        manual_seed(seed)
        cudnn.deterministic = True
        cudnn.benchmark = False

    # loss function (criterion)
    criterion = nn.CrossEntropyLoss().to(config.device)

    # loaders stay None for an export-only run
    train_loader = None
    train_sampler = None
    val_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path
    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)

    if config.to_onnx is not None:
        # exporting requires weights from either a pretrained model or a checkpoint
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        # data loading
        train_dataset, val_dataset = create_datasets(config)
        train_loader, train_sampler, val_loader = create_data_loaders(config, train_dataset, val_dataset)
        nncf_config = register_default_init_args(nncf_config, criterion, train_loader)

    # build the model and keep an uncompressed copy of it
    model_name = config['model']
    model = load_model(model_name,
                       pretrained=pretrained,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params'),
                       weights_path=config.get('weights'))
    reference_model = copy.deepcopy(model)

    model.to(config.device)

    # the checkpoint (if any) supplies the state dict used while compressing the model
    resuming_checkpoint = None
    resuming_model_sd = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(resuming_checkpoint_path)
        resuming_model_sd = resuming_checkpoint['state_dict']

    compression_ctrl, model = create_compressed_model(model, nncf_config, resuming_model_sd)
    supported_controllers = (BinarizationController, QuantizationController)
    if not isinstance(compression_ctrl, supported_controllers):
        raise RuntimeError(
            "The stage quantization sample worker may only be run with the binarization and quantization algorithms!"
        )

    model, _ = prepare_model_for_execution(model, config)
    reference_model.to(config.device)

    if config.distributed:
        compression_ctrl.distributed()

    is_inception = 'inception' in model_name

    # optimizer and its scheduler, configured from the (first) compression section
    trainable_params = model.parameters()
    compression_section = config['compression']
    if isinstance(compression_section, dict):
        quantization_config = compression_section
    else:
        quantization_config = compression_section[0]
    optimizer = get_quantization_optimizer(trainable_params, quantization_config)
    optimizer_scheduler = PolyLRDropScheduler(optimizer, quantization_config)
    kd_loss_calculator = KDLossCalculator(reference_model)

    best_acc1 = 0
    # optionally restore the training state from the checkpoint (skipped when exporting)
    if resuming_checkpoint is not None and config.to_onnx is None:
        config.start_epoch = resuming_checkpoint['epoch']
        best_acc1 = resuming_checkpoint['best_acc1']
        kd_loss_calculator.original_model.load_state_dict(resuming_checkpoint['original_model_state_dict'])
        compression_ctrl.scheduler.load_state_dict(resuming_checkpoint['compression_scheduler'])
        optimizer.load_state_dict(resuming_checkpoint['optimizer'])
        optimizer_scheduler.load_state_dict(resuming_checkpoint['optimizer_scheduler'])
        if config.mode.lower() == 'train':
            loaded_message = "=> loaded checkpoint '{}' (epoch: {}, best_acc1: {:.3f})".format(
                resuming_checkpoint_path, resuming_checkpoint['epoch'], best_acc1)
        else:
            loaded_message = "=> loaded checkpoint '{}'".format(resuming_checkpoint_path)
        logger.info(loaded_message)

    # export-only run: write the ONNX file and stop
    if config.to_onnx:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if config.execution_mode != ExecutionMode.CPU_ONLY:
        cudnn.benchmark = True

    if config.mode.lower() == 'test':
        print_statistics(compression_ctrl.statistics())
        validate(val_loader, model, criterion, config)

    if config.mode.lower() == 'train':
        algo_params = quantization_config.get("params", {})
        batch_multiplier = algo_params.get("batch_multiplier", 1)
        train_staged(config, compression_ctrl, model, criterion, is_inception,
                     optimizer_scheduler, model_name, optimizer, train_loader,
                     train_sampler, val_loader, kd_loss_calculator,
                     batch_multiplier, best_acc1)
def train_staged(config,
                 compression_ctrl,
                 model,
                 criterion,
                 is_inception,
                 optimizer_scheduler,
                 model_name,
                 optimizer,
                 train_loader,
                 train_sampler,
                 val_loader,
                 kd_loss_calculator,
                 batch_multiplier,
                 best_acc1=0):
    """Run the staged training loop from ``config.start_epoch`` up to ``config.epochs``.

    Each epoch trains once via ``train_epoch_staged``, collects the compression
    statistics, validates whenever ``epoch % config.test_every_n_epochs == 0``, decides
    whether the epoch produced the best checkpoint, steps the compression and
    optimizer schedulers and, on the main process only, saves a ``*_last.pth``
    checkpoint and writes the numeric statistics to ``config.tb``.

    A checkpoint is "best" when the compression level exceeds the best level seen so
    far, or when the level equals it and acc@1 improved. ``best_acc1`` therefore always
    refers to the accuracy of the current best checkpoint: it is replaced (not
    max-merged) whenever a new best checkpoint is found, so an accuracy reached at a
    lower compression level does not block later improvements at a higher level.

    ``best_acc1`` is the best top-1 accuracy known before the loop starts.
    """
    best_compression_level = CompressionLevel.NONE
    for epoch in range(config.start_epoch, config.epochs):
        config.cur_epoch = epoch
        if config.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train_epoch_staged(train_loader, batch_multiplier, model, criterion,
                           optimizer, optimizer_scheduler, kd_loss_calculator,
                           compression_ctrl, epoch, config, is_inception)

        # compute compression algo statistics
        stats = compression_ctrl.statistics()

        # without an evaluation this epoch, acc1 falls back to the running best
        acc1 = best_acc1
        if epoch % config.test_every_n_epochs == 0:
            # evaluate on validation set
            acc1, _ = validate(val_loader, model, criterion, config)

        compression_level = compression_ctrl.compression_level()
        # remember best acc@1, considering compression level. If current acc@1 is less than the best acc@1, the
        # checkpoint can still be the best one if the current compression level is bigger than the best one.
        # Compression levels in ascending order: NONE, PARTIAL, FULL.
        is_best_by_accuracy = acc1 > best_acc1 and compression_level == best_compression_level
        is_best = is_best_by_accuracy or compression_level > best_compression_level
        if is_best:
            # keep best_acc1 tied to the best checkpoint; taking max() here would carry an accuracy from a
            # lower compression level forward and suppress every later "best" at the new level
            best_acc1 = acc1
        best_compression_level = max(compression_level, best_compression_level)

        # statistics (e.g. portion of the enabled quantizers) is related to the finished epoch,
        # hence printing should happen before epoch_step, which may inform about state of the next epoch (e.g. next
        # portion of enabled quantizers)
        if is_main_process():
            print_statistics(stats)

        # update compression scheduler state at the end of the epoch
        compression_ctrl.scheduler.epoch_step()
        optimizer_scheduler.epoch_step()

        if is_main_process():
            checkpoint_path = osp.join(config.checkpoint_save_dir,
                                       get_name(config) + '_last.pth')
            checkpoint = {
                'epoch': epoch + 1,
                'arch': model_name,
                'state_dict': model.state_dict(),
                'original_model_state_dict': kd_loss_calculator.original_model.state_dict(),
                'best_acc1': best_acc1,
                'compression_level': compression_level,
                'optimizer': optimizer.state_dict(),
                'compression_scheduler': compression_ctrl.scheduler.state_dict(),
                'optimizer_scheduler': optimizer_scheduler.state_dict()
            }

            torch.save(checkpoint, checkpoint_path)
            make_additional_checkpoints(checkpoint_path, is_best, epoch + 1,
                                        config)

            # only plain numeric statistics are logged as tensorboard scalars
            for key, value in stats.items():
                if isinstance(value, (int, float)):
                    config.tb.add_scalar(
                        "compression/statistics/{0}".format(key), value,
                        len(train_loader) * epoch)
Exemplo n.º 5
0
 def autoq_eval_fn(model, eval_loader):
     """Validate ``model`` on ``eval_loader`` and return the second value produced by ``validate``.

     ``criterion`` and ``config`` are taken from the enclosing scope.
     """
     # NOTE(review): the second value is presumably top-5 accuracy — confirm against ``validate``
     _top1, top5_acc = validate(eval_loader, model, criterion, config)
     return top5_acc