Example #1
0
File: main.py Project: zbrnwpu/nncf
def create_model(config: SampleConfig, resuming_model_sd: dict = None):
    """Build the SSD network, wrap it for compression and return it in train mode.

    The network is built with ``build_ssd``, optionally initialised from the
    file named by the ``weights`` config entry, moved to ``config.device`` and
    then passed through ``create_compressed_model`` and
    ``prepare_model_for_execution``.

    Args:
        config: sample configuration; ``nncf_config``, ``model``,
            ``ssd_params``, ``num_classes``, ``device`` and the optional
            ``weights`` entry are read here.
        resuming_model_sd: forwarded unchanged as the third argument of
            ``create_compressed_model`` (``None`` by default).

    Returns:
        A ``(compression_ctrl, compressed_model)`` tuple; ``train()`` has been
        called on the model.
    """
    # The size handed to the SSD builder is the last dimension of the first
    # input description's shape.
    first_input_info = create_input_infos(config.nncf_config)[0]
    image_size = first_input_info.shape[-1]
    net = build_ssd(config.model, config.ssd_params, image_size, config.num_classes, config)

    weights_path = config.get('weights')
    if weights_path:
        # Load onto the CPU first; the network is moved to the target device below.
        load_state(net, torch.load(weights_path, map_location='cpu'))

    net.to(config.device)

    ctrl, model = create_compressed_model(net, config.nncf_config, resuming_model_sd)
    # Only the first element of the returned pair is used here.
    model, _unused = prepare_model_for_execution(model, config)

    model.train()
    return ctrl, model
Example #2
0
def is_pretrained_model_requested(config: SampleConfig) -> bool:
    """Tell whether the model should be created with pretrained weights.

    An explicit ``weights`` entry (anything other than ``None``) always
    disables the pretrained request. Otherwise the ``pretrained`` entry is
    returned as-is, defaulting to ``True`` when it is absent.
    """
    if config.get('weights') is not None:
        return False
    return config.get('pretrained', True)
Example #3
0
def main_worker(current_gpu, config: SampleConfig):
    """Build and compress the classification model, then export, test or train.

    Steps, in order:
      1. Store ``current_gpu`` on ``config``, derive ``config.distributed``
         from the execution mode and set ``config.device``.
      2. Configure logging on the main process and seed the RNG if a seed is
         configured.
      3. Unless an ONNX export is requested, create datasets and loaders and
         register the default NNCF init args.
      4. Load the model, optionally load a resuming checkpoint, and wrap the
         model with ``create_compressed_model``.
      5. If ``config.to_onnx`` is truthy, export the model and return.
      6. Otherwise build the optimizer, optionally restore training state, and
         run validation (``'test'`` mode) or training (``'train'`` mode).

    Args:
        current_gpu: stored as ``config.current_gpu``; not used otherwise in
            this function.
        config: sample configuration. Mutated in place: ``current_gpu``,
            ``distributed``, ``device`` and, when training is resumed,
            ``start_epoch``.

    Raises:
        AssertionError: if ``config.to_onnx`` is not ``None`` while neither a
            pretrained model nor a resuming checkpoint path is available.
    """
    config.current_gpu = current_gpu
    distributed_modes = (ExecutionMode.DISTRIBUTED,
                         ExecutionMode.MULTIPROCESSING_DISTRIBUTED)
    config.distributed = config.execution_mode in distributed_modes
    if config.distributed:
        configure_distributed(config)

    config.device = get_device(config)

    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    if config.seed is not None:
        manual_seed(config.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False

    # Loss function (criterion), placed on the selected device.
    loss_fn = nn.CrossEntropyLoss().to(config.device)

    train_loader, train_sampler, val_loader = None, None, None
    checkpoint_path = config.resuming_checkpoint_path
    nncf_config = config.nncf_config
    onnx_path = config.to_onnx

    use_pretrained = is_pretrained_model_requested(config)

    if onnx_path is None:
        # Data is only needed when the run is not an export.
        train_dataset, val_dataset = create_datasets(config)
        train_loader, train_sampler, val_loader = create_data_loaders(
            config, train_dataset, val_dataset)
        nncf_config = register_default_init_args(nncf_config, loss_fn,
                                                 train_loader)
    else:
        # An export needs weights from somewhere: pretrained or a checkpoint.
        assert use_pretrained or checkpoint_path is not None

    # Create the model.
    model_name = config['model']
    net = load_model(
        model_name,
        pretrained=use_pretrained,
        num_classes=config.get('num_classes', 1000),
        model_params=config.get('model_params'),
        weights_path=config.get('weights'),
    )
    net.to(config.device)

    checkpoint = None
    checkpoint_state_dict = None
    if checkpoint_path is not None:
        checkpoint = load_resuming_checkpoint(checkpoint_path)
        checkpoint_state_dict = checkpoint['state_dict']

    compression_ctrl, net = create_compressed_model(
        net, nncf_config, resuming_state_dict=checkpoint_state_dict)

    # NOTE(review): this gate tests truthiness while the data-loading branch
    # above tests ``is None``; an empty-but-not-None value would skip both
    # data loading and export -- confirm that cannot happen.
    if onnx_path:
        compression_ctrl.export_model(onnx_path)
        logger.info("Saved to {}".format(onnx_path))
        return

    net, _unused = prepare_model_for_execution(net, config)
    if config.distributed:
        compression_ctrl.distributed()

    # Define the optimizer.
    param_groups = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(param_groups, config)

    best_acc1 = 0
    # Optionally resume from a checkpoint.
    if checkpoint_path is not None:
        resume_training = (config.mode.lower() == 'train'
                           and onnx_path is None)
        if not resume_training:
            # Only the model weights are taken from the checkpoint.
            logger.info(
                "=> loaded checkpoint '{}'".format(checkpoint_path))
        else:
            # Restore epoch counter, best accuracy, compression scheduler
            # and optimizer state.
            config.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            compression_ctrl.scheduler.load_state_dict(
                checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info(
                "=> loaded checkpoint '{}' (epoch: {}, best_acc1: {:.3f})".
                format(checkpoint_path, checkpoint['epoch'], best_acc1))

    # NOTE(review): for seeded runs this overrides the ``benchmark = False``
    # set above -- confirm this is intended.
    if config.execution_mode != ExecutionMode.CPU_ONLY:
        cudnn.benchmark = True

    mode = config.mode.lower()

    if mode == 'test':
        print_statistics(compression_ctrl.statistics())
        validate(val_loader, net, loss_fn, config)

    if mode == 'train':
        is_inception = 'inception' in model_name
        train(config, compression_ctrl, net, loss_fn, is_inception,
              lr_scheduler, model_name, optimizer, train_loader, train_sampler,
              val_loader, best_acc1)