def create_initialized_compressed_model(model: nn.Module, config: NNCFConfig, train_loader: DataLoader) -> nn.Module:
    """Compress *model* and run the default NNCF initialization over *train_loader*.

    A deep copy of *config* is registered so the caller's config object is not
    mutated by ``register_default_init_args``.

    :param model: model to wrap with NNCF compression.
    :param config: NNCF configuration (left untouched).
    :param train_loader: data loader used for the initialization run.
    :return: the compressed (NNCF-wrapped) model.
    """
    # Bug fix: the criterion must be an nn.Module *instance*; the original code
    # passed the nn.MSELoss class itself, which would be "called" with
    # (outputs, targets) as constructor arguments during initialization.
    config = register_default_init_args(deepcopy(config), train_loader, nn.MSELoss())
    model, _compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
    return model
def create_finetuned_lenet_model_and_dataloader(config, eval_fn, finetuning_steps, learning_rate=1e-3):
    """Build a small LeNet, fine-tune it briefly on mock data, then compress it.

    The weight init and the fine-tuning loop run under a fixed torch seed so the
    resulting model is deterministic.

    :return: tuple of (compressed model, training data loader, compression controller).
    """
    with set_torch_seed():
        train_loader = create_ones_mock_dataloader(config, num_samples=10)
        model = LeNet()
        # Deterministic, near-zero uniform weight initialization.
        for parameter in model.parameters():
            nn.init.uniform_(parameter, a=0.0, b=0.01)

        batch_iter = iter(train_loader)
        optimizer = SGD(model.parameters(), lr=learning_rate)
        for _ in range(finetuning_steps):
            optimizer.zero_grad()
            inputs, target = next(batch_iter)
            loss = F.mse_loss(model(inputs).sum(), target)
            loss.backward()
            optimizer.step()

    config = register_default_init_args(
        config,
        train_loader=train_loader,
        model_eval_fn=partial(eval_fn, train_loader=train_loader))
    model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
    return model, train_loader, compression_ctrl
def test_legr_class_setting_params(tmp_path):
    """Check that LeGR hyper-parameters given in the NNCF config override the defaults."""
    generations_ref = 150
    train_steps_ref = 50
    max_pruning_ref = 0.1
    random_seed_ref = 1

    model = PruningTestModel()
    config = create_default_legr_config()
    config['compression']['params']['legr_params'] = {}
    config['compression']['params']['legr_params']['generations'] = generations_ref
    config['compression']['params']['legr_params']['train_steps'] = train_steps_ref
    config['compression']['params']['legr_params']['max_pruning'] = max_pruning_ref
    config['compression']['params']['legr_params']['random_seed'] = random_seed_ref

    train_loader = create_ones_mock_dataloader(config)
    val_loader = create_ones_mock_dataloader(config)
    train_steps_fn = lambda *x: None
    validate_fn = lambda *x: (0, 0)
    nncf_config = register_default_init_args(config, train_loader=train_loader,
                                             train_steps_fn=train_steps_fn,
                                             val_loader=val_loader, validate_fn=validate_fn)
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, nncf_config)

    # Bug fix: the original used `=` (attribute assignment) instead of
    # `assert ... == ...`, so the test silently overwrote the values under
    # check and verified nothing.
    assert compression_ctrl.legr.num_generations == generations_ref
    assert compression_ctrl.legr.max_pruning == max_pruning_ref
    assert compression_ctrl.legr._train_steps == train_steps_ref
    assert compression_ctrl.legr.random_seed == random_seed_ref
def test_mock_dump_checkpoint(aa_config):
    """The early-exit training loop must call a user-supplied dump_checkpoint_fn
    with the expected argument types."""
    dump_fn_was_called = False

    def mock_dump_checkpoint_fn(model, compression_controller, accuracy_aware_runner, aa_log_dir):
        from nncf.api.compression import CompressionAlgorithmController
        from nncf.common.accuracy_aware_training.runner import TrainingRunner

        # Verify the training loop hands over the documented object types.
        assert isinstance(model, torch.nn.Module)
        assert isinstance(compression_controller, CompressionAlgorithmController)
        assert isinstance(accuracy_aware_runner, TrainingRunner)
        assert isinstance(aa_log_dir, str)
        nonlocal dump_fn_was_called
        dump_fn_was_called = True

    config = get_quantization_config_without_range_init(LeNet.INPUT_SIZE[-1])
    train_loader = create_ones_mock_dataloader(aa_config, num_samples=10)
    model = LeNet()
    config.update(aa_config)

    def train_fn(compression_ctrl, model, epoch, optimizer, lr_scheduler, train_loader=train_loader):
        pass

    def mock_validate_fn(model, init_step=False, epoch=0):
        return 80

    def configure_optimizers_fn():
        return SGD(model.parameters(), lr=0.001), None

    config = register_default_init_args(config, train_loader=train_loader,
                                        model_eval_fn=partial(mock_validate_fn, init_step=True))
    model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

    training_loop = EarlyExitCompressionTrainingLoop(config, compression_ctrl, dump_checkpoints=True)
    model = training_loop.run(model,
                              train_epoch_fn=train_fn,
                              validate_fn=partial(mock_validate_fn),
                              configure_optimizers_fn=configure_optimizers_fn,
                              dump_checkpoint_fn=mock_dump_checkpoint_fn)
    assert dump_fn_was_called
def test_legr_class_default_params(tmp_path):
    """Check the LeGR default hyper-parameters when none are set in the config."""
    model = PruningTestModel()
    config = create_default_legr_config()
    train_loader = create_ones_mock_dataloader(config)
    val_loader = create_ones_mock_dataloader(config)
    train_steps_fn = lambda *x: None
    validate_fn = lambda *x: (0, 0)
    nncf_config = register_default_init_args(config, train_loader=train_loader,
                                             train_steps_fn=train_steps_fn,
                                             val_loader=val_loader, validate_fn=validate_fn)
    _, compression_ctrl = create_compressed_model_and_algo_for_test(model, nncf_config)

    # Bug fix: the original used `=` (attribute assignment) instead of
    # `assert ... == ...`, overwriting the defaults it was meant to verify.
    assert compression_ctrl.legr.num_generations == 400
    assert compression_ctrl.legr.max_pruning == 0.8
    assert compression_ctrl.legr._train_steps == 200
    assert compression_ctrl.legr.random_seed == 42
def test_legr_reproducibility():
    """Two LeGR runs with identical configs must produce identical ranking coefficients."""
    np.random.seed(42)
    config = create_default_legr_config()
    train_loader = create_ones_mock_dataloader(config)
    val_loader = create_ones_mock_dataloader(config)
    train_steps_fn = lambda *x: None
    # The mocked metric draws from numpy's RNG, seeded above.
    validate_fn = lambda *x: (0, np.random.random())
    nncf_config = register_default_init_args(config, train_loader=train_loader,
                                             train_steps_fn=train_steps_fn,
                                             val_loader=val_loader, validate_fn=validate_fn)

    model_1 = PruningTestModel()
    _, compression_ctrl_1 = create_compressed_model_and_algo_for_test(model_1, nncf_config)

    model_2 = PruningTestModel()
    # Bug fix: the second run was built from the raw `config` object instead of
    # the registered `nncf_config`, so the two runs being compared were not
    # guaranteed an identical setup.
    _, compression_ctrl_2 = create_compressed_model_and_algo_for_test(model_2, nncf_config)

    assert compression_ctrl_1.ranking_coeffs == compression_ctrl_2.ranking_coeffs
def test_accuracy_aware_config(aa_config, must_raise):
    """An invalid accuracy-aware config must make the training-loop factory raise."""
    def mock_validate_fn(model):
        pass

    base_aa_section = {
        "accuracy_aware_training": {
            "mode": "adaptive_compression_level",
            "params": {
                "maximal_relative_accuracy_degradation": 1,
                "initial_training_phase_epochs": 1,
                "patience_epochs": 10
            }
        }
    }

    config = get_quantization_config_without_range_init(LeNet.INPUT_SIZE[-1])
    config.update(base_aa_section)
    # The parametrized config overrides/extends the baseline section above.
    config.update(aa_config)

    train_loader = create_ones_mock_dataloader(config, num_samples=10)
    model = LeNet()
    config = register_default_init_args(config, train_loader=train_loader,
                                        model_eval_fn=mock_validate_fn)
    model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

    if must_raise:
        with pytest.raises(RuntimeError):
            _ = create_accuracy_aware_training_loop(config, compression_ctrl,
                                                    dump_checkpoints=False)
    else:
        _ = create_accuracy_aware_training_loop(config, compression_ctrl,
                                                dump_checkpoints=False)
def test_early_exit_with_mock_validation(max_accuracy_degradation, exit_epoch_number, maximal_total_epochs=100):
    """The early-exit loop must stop at the exact epoch where the mocked metric
    first satisfies the configured accuracy-degradation criterion."""
    epoch_counter = 0

    def mock_validate_fn(model, init_step=False, epoch=0):
        # Baseline metric of the uncompressed model.
        original_metric = 0.85
        if init_step:
            return original_metric
        nonlocal epoch_counter
        epoch_counter = epoch
        # The mocked metric grows linearly with the epoch number and reaches the
        # acceptable level exactly at `exit_epoch_number`.
        if "maximal_relative_accuracy_degradation" in max_accuracy_degradation:
            acceptable = original_metric * (1 - 0.01 * max_accuracy_degradation[
                'maximal_relative_accuracy_degradation'])
            return acceptable * (epoch / exit_epoch_number)
        return (original_metric - max_accuracy_degradation['maximal_absolute_accuracy_degradation']) * \
               epoch / exit_epoch_number

    config = get_quantization_config_without_range_init(LeNet.INPUT_SIZE[-1])
    aa_params = {"maximal_total_epochs": maximal_total_epochs}
    aa_params.update(max_accuracy_degradation)
    config.update({
        "accuracy_aware_training": {
            "mode": "early_exit",
            "params": aa_params
        }
    })

    train_loader = create_ones_mock_dataloader(config, num_samples=10)
    model = LeNet()
    config = register_default_init_args(config, train_loader=train_loader,
                                        model_eval_fn=partial(mock_validate_fn, init_step=True))
    model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)

    def train_fn(compression_ctrl, model, epoch, optimizer, lr_scheduler, train_loader=train_loader):
        pass

    def configure_optimizers_fn():
        return None, None

    training_loop = EarlyExitCompressionTrainingLoop(config, compression_ctrl,
                                                     dump_checkpoints=False)
    model = training_loop.run(model,
                              train_epoch_fn=train_fn,
                              validate_fn=partial(mock_validate_fn),
                              configure_optimizers_fn=configure_optimizers_fn)
    # Epoch numbering starts from 0.
    assert epoch_counter == exit_epoch_number
def wrap_nncf_model(model, cfg, checkpoint_dict=None, datamanager_for_init=None):
    """Wrap a torchreid model with NNCF compression.

    Either resumes from an NNCF checkpoint (``checkpoint_dict`` or
    ``cfg.model.load_weights``) or initializes compression from scratch using
    ``datamanager_for_init``; at least one of the two must be provided.

    :return: tuple of (compression controller, compressed model, nncf_metainfo dict).
    """
    # Note that we require to import it here to avoid cyclic imports when import get_no_nncf_trace_context_manager
    # from mobilenetv3
    from torchreid.data.transforms import build_inference_transform

    from nncf import NNCFConfig
    from nncf.torch import create_compressed_model, load_state
    from nncf.torch.initialization import register_default_init_args
    from nncf.torch.dynamic_graph.io_handling import nncf_model_input
    from nncf.torch.dynamic_graph.trace_tensor import TracedTensor
    from nncf.torch.initialization import PTInitializingDataLoader

    # Resolve the checkpoint: an in-memory dict takes priority over the path in cfg.
    if checkpoint_dict is None:
        checkpoint_path = cfg.model.load_weights
        resuming_checkpoint = safe_load_checkpoint(
            checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint_path = 'pretrained_dict'
        resuming_checkpoint = checkpoint_dict

    if datamanager_for_init is None and not is_nncf_state(resuming_checkpoint):
        raise RuntimeError('Either datamanager_for_init or NNCF pre-trained '
                           'model checkpoint should be set')

    nncf_metainfo = None
    if is_nncf_state(resuming_checkpoint):
        # NNCF checkpoint: take the NNCF config stored inside it and skip
        # data-driven initialization entirely.
        nncf_metainfo = _get_nncf_metainfo_from_state(resuming_checkpoint)
        nncf_config_data = nncf_metainfo['nncf_config']
        datamanager_for_init = None
        logger.info(f'Read NNCF metainfo with NNCF config from the checkpoint:'
                    f'nncf_metainfo=\n{pformat(nncf_metainfo)}')
    else:
        # Plain (non-NNCF) checkpoint: discard it here; compression will be
        # initialized from data, and the NNCF config is read from cfg.
        resuming_checkpoint = None
        nncf_config_data = cfg.get('nncf_config')

        if nncf_config_data is None:
            logger.info('Cannot read nncf_config from config file')
        else:
            logger.info(f' nncf_config=\n{pformat(nncf_config_data)}')

    h, w = cfg.data.height, cfg.data.width
    if not nncf_config_data:
        logger.info('Using the default NNCF int8 quantization config')
        nncf_config_data = get_default_nncf_compression_config(h, w)

    # do it even if nncf_config_data is loaded from a checkpoint -- for the rare case when
    # the width and height of the model's input was changed in the config
    # and then finetuning of NNCF model is run
    nncf_config_data.setdefault('input_info', {})
    nncf_config_data['input_info']['sample_size'] = [1, 3, h, w]

    nncf_config = NNCFConfig(nncf_config_data)
    logger.info(f'nncf_config =\n{pformat(nncf_config)}')
    if not nncf_metainfo:
        nncf_metainfo = create_nncf_metainfo(enable_quantization=True,
                                             enable_pruning=False,
                                             nncf_config=nncf_config_data)
    else:
        # update it just to be on the safe side
        nncf_metainfo['nncf_config'] = nncf_config_data

    class ReidInitializeDataLoader(PTInitializingDataLoader):
        # Adapter so NNCF initialization can pull model inputs out of the
        # torchreid data loader batches.
        def get_inputs(self, dataloader_output):
            # define own InitializingDataLoader class using approach like
            # parse_data_for_train and parse_data_for_eval in the class Engine
            # dataloader_output[0] should be image here
            args = (dataloader_output[0], )
            return args, {}

    @torch.no_grad()
    def model_eval_fn(model):
        """
        Runs evaluation of the model on the validation set and
        returns the target metric value.
        Used to evaluate the original model before compression
        if NNCF-based accuracy-aware training is used.
        """
        from torchreid.metrics.classification import evaluate_classification

        # NOTE(review): `test_loader` and `cur_device` are closed over from the
        # enclosing scope and only assigned later (and only on the
        # from-scratch-init path) — this fn must not be called before that.
        if test_loader is None:
            raise RuntimeError(
                'Cannot perform a model evaluation on the validation '
                'dataset since the validation data loader was not passed '
                'to wrap_nncf_model')

        model_type = get_model_attr(model, 'type')
        targets = list(test_loader.keys())
        use_gpu = cur_device.type == 'cuda'
        for dataset_name in targets:
            domain = 'source' if dataset_name in datamanager_for_init.sources else 'target'
            print(f'##### Evaluating {dataset_name} ({domain}) #####')
            if model_type == 'classification':
                cmc, _, _ = evaluate_classification(
                    test_loader[dataset_name]['query'], model, use_gpu=use_gpu)
                accuracy = cmc[0]
            elif model_type == 'multilabel':
                mAP, _, _, _, _, _, _ = evaluate_multilabel_classification(
                    test_loader[dataset_name]['query'], model, use_gpu=use_gpu)
                accuracy = mAP
            else:
                raise ValueError(
                    f'Cannot perform a model evaluation on the validation dataset'
                    f'since the model has unsupported model_type {model_type or "None"}')
        # NOTE(review): only the metric of the *last* dataset in `targets` is
        # returned — presumably intentional; confirm against callers.
        return accuracy

    cur_device = next(model.parameters()).device
    logger.info(f'NNCF: cur_device = {cur_device}')

    if resuming_checkpoint is None:
        # From-scratch compression: register the data loader, eval fn and
        # device so NNCF can run its default initialization.
        logger.info('No NNCF checkpoint is provided -- register initialize data loader')
        train_loader = datamanager_for_init.train_loader
        test_loader = datamanager_for_init.test_loader
        wrapped_loader = ReidInitializeDataLoader(train_loader)
        nncf_config = register_default_init_args(nncf_config, wrapped_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=cur_device)
        model_state_dict = None
        compression_state = None
    else:
        # Resume: split the checkpoint into model weights and NNCF compression state.
        model_state_dict, compression_state = extract_model_and_compression_states(
            resuming_checkpoint)

    transform = build_inference_transform(
        cfg.data.height,
        cfg.data.width,
        norm_mean=cfg.data.norm_mean,
        norm_std=cfg.data.norm_std,
    )

    def dummy_forward(model):
        # Runs one synthetic forward pass so NNCF can trace the model graph;
        # restores the original train/eval mode afterwards.
        prev_training_state = model.training
        model.eval()
        input_img = random_image(cfg.data.height, cfg.data.width)
        input_blob = transform(input_img).unsqueeze(0)
        assert len(input_blob.size()) == 4
        input_blob = input_blob.to(device=cur_device)
        input_blob = nncf_model_input(input_blob)
        model(input_blob)
        model.train(prev_training_state)

    def wrap_inputs(args, kwargs):
        # Marks the (single) positional input as an NNCF model input, unless
        # it is already a traced tensor.
        assert len(args) == 1
        if isinstance(args[0], TracedTensor):
            logger.info('wrap_inputs: do not wrap input TracedTensor')
            return args, {}
        return (nncf_model_input(args[0]), ), kwargs

    model.dummy_forward_fn = dummy_forward

    if 'log_dir' in nncf_config:
        os.makedirs(nncf_config['log_dir'], exist_ok=True)
        logger.info(f'nncf_config["log_dir"] = {nncf_config["log_dir"]}')

    compression_ctrl, model = create_compressed_model(model,
                                                      nncf_config,
                                                      dummy_forward_fn=dummy_forward,
                                                      wrap_inputs_fn=wrap_inputs,
                                                      compression_state=compression_state)

    if model_state_dict:
        logger.info(f'Loading NNCF model from {checkpoint_path}')
        load_state(model, model_state_dict, is_resume=True)

    return compression_ctrl, model, nncf_metainfo
def staged_quantization_main_worker(current_gpu, config):
    """Per-process worker for the staged quantization/binarization sample.

    Sets up data, builds and (optionally) resumes a compressed model, then runs
    staged training, validation and/or ONNX export according to ``config.mode``.
    """
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(config.device)

    model_name = config['model']
    # Inception models return auxiliary logits and need a dedicated criterion wrapper.
    is_inception = 'inception' in model_name
    train_criterion_fn = inception_criterion_fn if is_inception else default_criterion_fn

    train_loader = train_sampler = val_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path
    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)
    # Export-only mode skips dataset creation entirely.
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)

    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        # Data loading code
        train_dataset, val_dataset = create_datasets(config)
        train_loader, train_sampler, val_loader, init_loader = create_data_loaders(
            config, train_dataset, val_dataset)

        def autoq_eval_fn(model, eval_loader):
            # AutoQ evaluates candidate quantization policies by top-5 accuracy.
            _, top5, _ = validate(eval_loader, model, criterion, config)
            return top5

        nncf_config = register_default_init_args(
            nncf_config, init_loader,
            criterion=criterion,
            criterion_fn=train_criterion_fn,
            autoq_eval_fn=autoq_eval_fn,
            val_loader=val_loader,
            device=config.device)

    # create model
    model_name = config['model']
    model = load_model(model_name,
                       pretrained=pretrained,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params'),
                       weights_path=config.get('weights'))
    # Keep an uncompressed copy as the knowledge-distillation teacher.
    original_model = copy.deepcopy(model)

    model.to(config.device)

    # extract_model_and_compression_states handles resuming_checkpoint=None.
    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(resuming_checkpoint_path)
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint)
    compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state)
    if model_state_dict is not None:
        load_state(model, model_state_dict, is_resume=True)

    if not isinstance(compression_ctrl, (BinarizationController, QuantizationController)):
        raise RuntimeError(
            "The stage quantization sample worker may only be run with the binarization and quantization algorithms!"
        )

    model, _ = prepare_model_for_execution(model, config)
    original_model.to(config.device)

    if config.distributed:
        compression_ctrl.distributed()

    params_to_optimize = model.parameters()

    # The compression section may be a single dict or a list of algorithm dicts;
    # take the first entry in the latter case.
    compression_config = config['compression']
    quantization_config = compression_config if isinstance(
        compression_config, dict) else compression_config[0]
    optimizer = get_quantization_optimizer(params_to_optimize, quantization_config)
    optimizer_scheduler = PolyLRDropScheduler(optimizer, quantization_config)
    kd_loss_calculator = KDLossCalculator(original_model)

    best_acc1 = 0
    # optionally resume from a checkpoint
    if resuming_checkpoint is not None and config.to_onnx is None:
        config.start_epoch = resuming_checkpoint['epoch']
        best_acc1 = resuming_checkpoint['best_acc1']
        # Restore the KD teacher's weights as they were at checkpoint time.
        kd_loss_calculator.original_model.load_state_dict(
            resuming_checkpoint['original_model_state_dict'])
        if 'train' in config.mode:
            optimizer.load_state_dict(resuming_checkpoint['optimizer'])
            optimizer_scheduler.load_state_dict(
                resuming_checkpoint['optimizer_scheduler'])
            logger.info(
                "=> loaded checkpoint '{}' (epoch: {}, best_acc1: {:.3f})".
                format(resuming_checkpoint_path, resuming_checkpoint['epoch'], best_acc1))
        else:
            logger.info(
                "=> loaded checkpoint '{}'".format(resuming_checkpoint_path))

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if config.execution_mode != ExecutionMode.CPU_ONLY:
        cudnn.benchmark = True

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if 'train' in config.mode:
        batch_multiplier = (quantization_config.get("params", {})).get(
            "batch_multiplier", 1)
        train_staged(config, compression_ctrl, model, criterion, train_criterion_fn,
                     optimizer_scheduler, model_name, optimizer, train_loader,
                     train_sampler, val_loader, kd_loss_calculator, batch_multiplier,
                     best_acc1)

    if 'test' in config.mode:
        validate(val_loader, model, criterion, config)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
def main_worker(current_gpu, config):
    """Per-process worker for the SSD object-detection sample.

    Sets up the environment, data loaders and compressed model, then runs
    (accuracy-aware) training, testing and/or ONNX export per ``config.mode``.
    """
    #################################
    # Setup experiment environment
    #################################
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    if is_on_first_rank(config):
        configure_logging(logger, config)
        print_args(config)
    set_seed(config)

    config.start_iter = 0
    nncf_config = config.nncf_config
    ##########################
    # Prepare metrics log file
    ##########################
    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    ###########################
    # Criterion
    ###########################
    criterion = MultiBoxLoss(config,
                             config['num_classes'],
                             overlap_thresh=0.5,
                             prior_for_matching=True,
                             bkg_label=0,
                             neg_mining=True,
                             neg_pos=3,
                             neg_overlap=0.5,
                             encode_target=False,
                             device=config.device)

    train_data_loader = test_data_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    ###########################
    # Prepare data
    ###########################
    pretrained = is_pretrained_model_requested(config)
    # Export-only mode skips dataset creation entirely.
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)

    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        test_data_loader, train_data_loader, init_data_loader = create_dataloaders(
            config)

        def criterion_fn(model_outputs, target, criterion):
            # MultiBox loss is the sum of localization and confidence losses.
            loss_l, loss_c = criterion(model_outputs, target)
            return loss_l + loss_c

        def autoq_test_fn(model, eval_loader):
            # RL is maximization, change the loss polarity
            return -1 * test_net(model, config.device, eval_loader,
                                 distributed=config.distributed,
                                 loss_inference=True, criterion=criterion)

        def model_eval_fn(model):
            model.eval()
            mAP = test_net(model, config.device, test_data_loader,
                           distributed=config.distributed,
                           criterion=criterion)
            return mAP

        nncf_config = register_default_init_args(nncf_config, init_data_loader,
                                                 criterion=criterion,
                                                 criterion_fn=criterion_fn,
                                                 autoq_eval_fn=autoq_test_fn,
                                                 val_loader=test_data_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=config.device)

    ##################
    # Prepare model
    ##################
    resuming_checkpoint_path = config.resuming_checkpoint_path
    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(resuming_checkpoint_path)
    compression_ctrl, net = create_model(config, resuming_checkpoint)
    if config.distributed:
        # Split the global batch/workers across the processes of this node.
        config.batch_size //= config.ngpus_per_node
        config.workers //= config.ngpus_per_node
        compression_ctrl.distributed()

    ###########################
    # Optimizer
    ###########################
    params_to_optimize = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    #################################
    # Load additional checkpoint data
    #################################
    if resuming_checkpoint_path is not None and 'train' in config.mode:
        optimizer.load_state_dict(
            resuming_checkpoint.get('optimizer', optimizer.state_dict()))
        config.start_epoch = resuming_checkpoint.get('epoch', 0) + 1

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if 'train' in config.mode and is_accuracy_aware_training(config):
        # validation function that returns the target metric value
        # pylint: disable=E1123
        def validate_fn(model, epoch):
            model.eval()
            mAP = test_net(model, config.device, test_data_loader,
                           distributed=config.distributed)
            model.train()
            return mAP

        # training function that trains the model for one epoch (full training dataset pass)
        # it is assumed that all the NNCF-related methods are properly called inside of
        # this function (like e.g. the step and epoch_step methods of the compression scheduler)
        def train_epoch_fn(compression_ctrl, model, epoch, optimizer, **kwargs):
            loc_loss = 0
            conf_loss = 0
            epoch_size = len(train_data_loader)
            train_epoch(compression_ctrl, model, config, train_data_loader, criterion,
                        optimizer, epoch_size, epoch, loc_loss, conf_loss)

        # function that initializes optimizers & lr schedulers to start training
        def configure_optimizers_fn():
            params_to_optimize = get_parameter_groups(net, config)
            optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)
            return optimizer, lr_scheduler

        acc_aware_training_loop = create_accuracy_aware_training_loop(
            nncf_config, compression_ctrl)
        net = acc_aware_training_loop.run(net,
                                          train_epoch_fn=train_epoch_fn,
                                          validate_fn=validate_fn,
                                          configure_optimizers_fn=configure_optimizers_fn,
                                          tensorboard_writer=config.tb,
                                          log_dir=config.log_dir)
    elif 'train' in config.mode:
        train(net, compression_ctrl, train_data_loader, test_data_loader, criterion,
              optimizer, config, lr_scheduler)

    if 'test' in config.mode:
        with torch.no_grad():
            net.eval()
            if config['ssd_params'].get('loss_inference', False):
                model_loss = test_net(net, config.device, test_data_loader,
                                      distributed=config.distributed,
                                      loss_inference=True, criterion=criterion)
                logger.info("Final model loss: {:.3f}".format(model_loss))
            else:
                mAp = test_net(net, config.device, test_data_loader,
                               distributed=config.distributed)
                if config.metrics_dump is not None:
                    write_metrics(mAp, config.metrics_dump)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
def main_worker(current_gpu, config):
    """Per-process worker for the semantic-segmentation sample.

    Sets up the environment, data and compressed model, then runs
    (accuracy-aware) training, testing and/or ONNX export per ``config.mode``.
    """
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)

    logger.info(config)

    dataset = get_dataset(config.dataset)
    color_encoding = dataset.color_encoding
    num_classes = len(color_encoding)

    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    train_loader = val_loader = criterion = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)

    def criterion_fn(model_outputs, target, criterion_):
        # Model-specific postprocessing aligns outputs/labels before the loss.
        labels, loss_outputs, _ = \
            loss_funcs.do_model_specific_postprocessing(config.model, target, model_outputs)
        return criterion_(loss_outputs, labels)

    # Export-only mode skips dataset creation entirely.
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)

    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        loaders, w_class = load_dataset(dataset, config)
        train_loader, val_loader, init_loader = loaders
        criterion = get_criterion(w_class, config)

        def autoq_test_fn(model, eval_loader):
            return test(model, eval_loader, criterion, color_encoding, config)

        model_eval_fn = functools.partial(autoq_test_fn, eval_loader=val_loader)

        nncf_config = register_default_init_args(nncf_config, init_loader,
                                                 criterion=criterion,
                                                 criterion_fn=criterion_fn,
                                                 autoq_eval_fn=autoq_test_fn,
                                                 val_loader=val_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=config.device)

    model = load_model(config.model,
                       pretrained=pretrained,
                       num_classes=num_classes,
                       model_params=config.get('model_params', {}),
                       weights_path=config.get('weights'))

    model.to(config.device)

    # extract_model_and_compression_states handles resuming_checkpoint=None.
    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(resuming_checkpoint_path)
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint)
    compression_ctrl, model = create_compressed_model(model, nncf_config, compression_state)
    if model_state_dict is not None:
        load_state(model, model_state_dict, is_resume=True)
    model, model_without_dp = prepare_model_for_execution(model, config)

    if config.distributed:
        compression_ctrl.distributed()

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if is_accuracy_aware_training(config) and 'train' in config.mode:
        # validation function that returns the target metric value
        def validate_fn(model, epoch):
            return test(model, val_loader, criterion, color_encoding, config)

        # training function that trains the model for one epoch (full training dataset pass)
        # it is assumed that all the NNCF-related methods are properly called inside of
        # this function (like e.g. the step and epoch_step methods of the compression scheduler)
        def train_epoch_fn(compression_ctrl, model, optimizer, **kwargs):
            ignore_index = None
            ignore_unlabeled = config.get("ignore_unlabeled", True)
            if ignore_unlabeled and ('unlabeled' in color_encoding):
                ignore_index = list(color_encoding).index('unlabeled')
            metric = IoU(len(color_encoding), ignore_index=ignore_index)
            train_obj = Train(model, train_loader, optimizer, criterion,
                              compression_ctrl, metric, config.device, config.model)
            train_obj.run_epoch(config.print_step)

        # function that initializes optimizers & lr schedulers to start training
        def configure_optimizers_fn():
            optim_config = config.get('optimizer', {})
            optim_params = optim_config.get('optimizer_params', {})
            lr = optim_params.get("lr", 1e-4)
            params_to_optimize = get_params_to_optimize(
                model_without_dp, lr * 10, config)
            optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)
            return optimizer, lr_scheduler

        # NOTE(review): the training loop is built from `config` here, whereas
        # the SSD worker passes `nncf_config` — presumably equivalent since
        # register_default_init_args returns the registered config; confirm.
        acc_aware_training_loop = create_accuracy_aware_training_loop(
            config, compression_ctrl)
        model = acc_aware_training_loop.run(model,
                                            train_epoch_fn=train_epoch_fn,
                                            validate_fn=validate_fn,
                                            configure_optimizers_fn=configure_optimizers_fn,
                                            tensorboard_writer=config.tb,
                                            log_dir=config.log_dir)
    elif 'train' in config.mode:
        train(model, model_without_dp, compression_ctrl, train_loader, val_loader,
              criterion, color_encoding, config, resuming_checkpoint)

    if 'test' in config.mode:
        logger.info(model)
        model_parameters = filter(lambda p: p.requires_grad, model.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        logger.info("Trainable argument count:{params}".format(params=params))
        model = model.to(config.device)
        test(model, val_loader, criterion, color_encoding, config)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))