Esempio n. 1
0
def load_detection_annotations(cachedir, dataset):
    """Return the dataset annotations, going through a JSON cache on disk.

    The cache file is ``<cachedir>/annots_<dataset.name>.json``. When it is
    missing, the main process builds it by pulling the annotation of every
    image from ``dataset`` and dumping the result as JSON. If distributed
    mode is available and initialized, every process waits on a barrier
    before the cache file is read back.

    :param cachedir: directory of the cache file (created when the cache is
        written).
    :param dataset: object providing ``name``, ``get_img_names()`` and
        ``pull_anno(index)``.
    :return: ``(gt, imagenames)`` - the content parsed from the cache file
        and the value returned by ``dataset.get_img_names()``.
    """
    cache_path = os.path.join(cachedir,
                              'annots_{}.json'.format(dataset.name))
    imagenames = dataset.get_img_names()

    # Only the main process is allowed to (re)build the cache.
    if is_main_process() and not os.path.isfile(cache_path):
        annotations = {}
        for index, name in enumerate(imagenames):
            _, annotation = dataset.pull_anno(index)
            annotations[name] = annotation

            if index % 100 == 0:
                logger.info('Reading annotation for {:d}/{:d}'.format(
                    index + 1, len(imagenames)))

        logger.info('Saving cached annotations to {:s}'.format(cache_path))
        pathlib.Path(cachedir).mkdir(parents=True, exist_ok=True)
        with open(cache_path, 'w', encoding='utf8') as out_file:
            json.dump(annotations, out_file)

    # Wait for the cache to be written before any process reads it.
    if is_dist_avail_and_initialized():
        dist.barrier()

    with open(cache_path, 'r', encoding='utf8') as in_file:
        gt = json.load(in_file)
    return gt, imagenames
Esempio n. 2
0
def train_epoch_staged(train_loader, batch_multiplier, model, criterion,
                       criterion_fn, optimizer,
                       optimizer_scheduler: PolyLRDropScheduler,
                       kd_loss_calculator: KDLossCalculator, compression_ctrl,
                       epoch, config):
    """Run one training epoch with a combined criterion + KD loss.

    For every batch the compression scheduler is stepped, the model output
    is scored with ``criterion_fn`` and with ``kd_loss_calculator.loss``,
    the sum of both losses is back-propagated and the optimizer scheduler
    is advanced by the fractional epoch position. Running averages are
    logged every ``config.print_freq`` iterations and, on the main process,
    written to ``config.tb`` on every iteration.

    :param train_loader: iterable of ``(input, target)`` batches; must
        support ``len()``.
    :param batch_multiplier: the optimizer is stepped only on iterations
        whose index is a multiple of this value.
    :param model: model to train (switched to train mode here).
    :param criterion: loss object forwarded to ``criterion_fn``.
    :param criterion_fn: callable ``(output, target, criterion) -> loss``.
    :param optimizer: optimizer updating ``model``.
    :param optimizer_scheduler: scheduler stepped once per iteration.
    :param kd_loss_calculator: provides ``loss(input, output)``.
    :param compression_ctrl: compression controller; its ``scheduler`` and
        ``statistics`` are used.
    :param epoch: index of the current epoch.
    :param config: sample configuration (``device``, ``print_freq``,
        ``rank``, ``multiprocessing_distributed``, ``tb`` are read).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    kd_losses_meter = AverageMeter()
    criterion_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    compression_scheduler = compression_ctrl.scheduler

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input_, target) in enumerate(train_loader):
        compression_scheduler.step()
        # measure data loading time
        data_time.update(time.time() - end)

        input_ = input_.to(config.device)
        target = target.to(config.device)

        output = model(input_)
        criterion_loss = criterion_fn(output, target, criterion)

        # Accuracy and KD loss are computed on the main logits only.
        if isinstance(output, InceptionOutputs):
            output = output.logits

        # compute KD loss
        kd_loss = kd_loss_calculator.loss(input_, output)
        loss = criterion_loss + kd_loss

        # measure accuracy and record loss; each meter is updated exactly
        # once per batch (top1 used to be updated twice, double-counting
        # every sample in that meter)
        batch_size = input_.size(0)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        kd_losses_meter.update(kd_loss.item(), batch_size)
        criterion_losses.update(criterion_loss.item(), batch_size)
        top1.update(acc1, batch_size)
        top5.update(acc5, batch_size)

        # compute gradient and do SGD step
        # NOTE(review): gradients accumulated on iterations that are not a
        # multiple of batch_multiplier appear to be cleared by zero_grad()
        # before the next optimizer.step() - confirm this is intended.
        if i % batch_multiplier == 0:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            loss.backward()

        optimizer_scheduler.step(float(i) / len(train_loader))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.print_freq == 0:
            logger.info('{rank}: '
                        'Epoch: [{0}][{1}/{2}] '
                        'Lr: {3:.3} '
                        'Wd: {4:.3} '
                        'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                        'Data: {data_time.val:.3f} ({data_time.avg:.3f}) '
                        'CE_loss: {ce_loss.val:.4f} ({ce_loss.avg:.4f}) '
                        'KD_loss: {kd_loss.val:.4f} ({kd_loss.avg:.4f}) '
                        'Loss: {loss.val:.4f} ({loss.avg:.4f}) '
                        'Acc@1: {top1.val:.3f} ({top1.avg:.3f}) '
                        'Acc@5: {top5.val:.3f} ({top5.avg:.3f})'.format(
                            epoch,
                            i,
                            len(train_loader),
                            get_lr(optimizer),
                            get_wd(optimizer),
                            batch_time=batch_time,
                            data_time=data_time,
                            ce_loss=criterion_losses,
                            kd_loss=kd_losses_meter,
                            loss=losses,
                            top1=top1,
                            top5=top5,
                            rank='{}:'.format(config.rank)
                            if config.multiprocessing_distributed else ''))

        if is_main_process():
            # Global iteration index used as the TensorBoard x-axis.
            global_step = len(train_loader) * epoch
            config.tb.add_scalar("train/learning_rate", get_lr(optimizer),
                                 i + global_step)
            config.tb.add_scalar("train/criterion_loss", criterion_losses.avg,
                                 i + global_step)
            config.tb.add_scalar("train/kd_loss", kd_losses_meter.avg,
                                 i + global_step)
            config.tb.add_scalar("train/loss", losses.avg, i + global_step)
            config.tb.add_scalar("train/top1", top1.avg, i + global_step)
            config.tb.add_scalar("train/top5", top5.avg, i + global_step)

            statistics = compression_ctrl.statistics(
                quickly_collected_only=True)
            for stat_name, stat_value in prepare_for_tensorboard(
                    statistics).items():
                config.tb.add_scalar('train/statistics/{}'.format(stat_name),
                                     stat_value, i + global_step)
Esempio n. 3
0
def train_staged(config,
                 compression_ctrl,
                 model,
                 criterion,
                 criterion_fn,
                 optimizer_scheduler,
                 model_name,
                 optimizer,
                 train_loader,
                 train_sampler,
                 val_loader,
                 kd_loss_calculator,
                 batch_multiplier,
                 best_acc1=0):
    """Drive staged training from ``config.start_epoch`` to ``config.epochs``.

    Each epoch steps the compression scheduler, trains through
    ``train_epoch_staged``, validates every ``config.test_every_n_epochs``
    epochs, steps the optimizer scheduler and, on the main process, writes
    a ``*_last.pth`` checkpoint and logs the compression statistics.

    A checkpoint counts as "best" when acc@1 improved at the same
    compression stage, or when the compression stage itself increased.

    :param best_acc1: best acc@1 known so far (e.g. restored on resume).
    """
    top_stage = CompressionStage.UNCOMPRESSED

    for epoch in range(config.start_epoch, config.epochs):
        # The compression scheduler is advanced at the start of the epoch.
        compression_ctrl.scheduler.epoch_step()

        if config.distributed:
            train_sampler.set_epoch(epoch)

        train_epoch_staged(train_loader, batch_multiplier, model, criterion,
                           criterion_fn, optimizer, optimizer_scheduler,
                           kd_loss_calculator, compression_ctrl, epoch,
                           config)

        # Statistics describe the epoch that has just finished.
        statistics = compression_ctrl.statistics()

        if epoch % config.test_every_n_epochs == 0:
            # pylint: disable=E1123
            acc1, _, _ = validate(val_loader, model, criterion, config,
                                  epoch=epoch)
        else:
            # No evaluation this epoch: reuse the best value seen so far.
            acc1 = best_acc1

        current_stage = compression_ctrl.compression_stage()
        # Stages in ascending order: UNCOMPRESSED, PARTIALLY_COMPRESSED,
        # FULLY_COMPRESSED. A higher stage wins even with a lower acc@1.
        better_accuracy = acc1 > best_acc1 and current_stage == top_stage
        is_best = better_accuracy or current_stage > top_stage
        best_acc1 = max(acc1, best_acc1)
        top_stage = max(current_stage, top_stage)

        # Print before optimizer_scheduler.epoch_step(), which may already
        # report the state of the next epoch.
        if is_main_process():
            logger.info(statistics.to_str())

        optimizer_scheduler.epoch_step()

        if not is_main_process():
            continue

        next_epoch = epoch + 1
        checkpoint_path = osp.join(config.checkpoint_save_dir,
                                   get_name(config) + '_last.pth')
        torch.save(
            {
                'epoch': next_epoch,
                'arch': model_name,
                MODEL_STATE_ATTR: model.state_dict(),
                COMPRESSION_STATE_ATTR:
                    compression_ctrl.get_compression_state(),
                'original_model_state_dict':
                    kd_loss_calculator.original_model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'optimizer_scheduler': optimizer_scheduler.state_dict(),
            }, checkpoint_path)
        make_additional_checkpoints(checkpoint_path, is_best, next_epoch,
                                    config)

        for key, value in prepare_for_tensorboard(statistics).items():
            config.mlflow.safe_call(
                'log_metric', 'compression/statistics/{0}'.format(key),
                value, epoch)
            config.tb.add_scalar("compression/statistics/{0}".format(key),
                                 value,
                                 len(train_loader) * epoch)
Esempio n. 4
0
def staged_quantization_main_worker(current_gpu, config):
    """Per-process entry point of the staged quantization sample.

    Configures the device and logging, builds the data loaders (unless only
    export is requested), creates the model and its compressed counterpart,
    sets up the optimizer, its scheduler and the KD loss, optionally
    restores state from a resuming checkpoint and finally runs every mode
    listed in ``config.mode`` ('train', 'test', 'export').

    :param current_gpu: forwarded to ``configure_device``.
    :param config: sample configuration; read both by attribute and by key.
    :raises RuntimeError: if the created compression controller is neither
        a ``BinarizationController`` nor a ``QuantizationController``.
    """
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)

    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(config.device)

    model_name = config['model']
    is_inception = 'inception' in model_name
    train_criterion_fn = inception_criterion_fn if is_inception else default_criterion_fn

    # Loaders stay None when only export is requested.
    train_loader = train_sampler = val_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path
    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)
    # Export-only: 'export' is requested without 'train' or 'test'.
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)

    if is_export_only:
        # Exporting needs weights from somewhere: pretrained or a checkpoint.
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        # Data loading code
        train_dataset, val_dataset = create_datasets(config)
        train_loader, train_sampler, val_loader, init_loader = create_data_loaders(
            config, train_dataset, val_dataset)

        # Evaluation callback registered below; returns the second value
        # produced by validate().
        def autoq_eval_fn(model, eval_loader):
            _, top5, _ = validate(eval_loader, model, criterion, config)
            return top5

        nncf_config = register_default_init_args(
            nncf_config,
            init_loader,
            criterion=criterion,
            criterion_fn=train_criterion_fn,
            autoq_eval_fn=autoq_eval_fn,
            val_loader=val_loader,
            device=config.device)

    # create model
    # NOTE(review): model_name was already read from config above; this
    # second assignment looks redundant.
    model_name = config['model']
    model = load_model(model_name,
                       pretrained=pretrained,
                       num_classes=config.get('num_classes', 1000),
                       model_params=config.get('model_params'),
                       weights_path=config.get('weights'))
    # Copy taken before compression is applied; handed to KDLossCalculator
    # further below.
    original_model = copy.deepcopy(model)

    model.to(config.device)

    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(
            resuming_checkpoint_path)
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint)
    compression_ctrl, model = create_compressed_model(model, nncf_config,
                                                      compression_state)
    if model_state_dict is not None:
        load_state(model, model_state_dict, is_resume=True)

    if not isinstance(compression_ctrl,
                      (BinarizationController, QuantizationController)):
        raise RuntimeError(
            "The stage quantization sample worker may only be run with the binarization and quantization algorithms!"
        )

    model, _ = prepare_model_for_execution(model, config)
    original_model.to(config.device)

    if config.distributed:
        compression_ctrl.distributed()

    params_to_optimize = model.parameters()

    # The 'compression' section is either a single dict or an indexable
    # collection, in which case its first entry is used.
    compression_config = config['compression']
    quantization_config = compression_config if isinstance(
        compression_config, dict) else compression_config[0]
    optimizer = get_quantization_optimizer(params_to_optimize,
                                           quantization_config)
    optimizer_scheduler = PolyLRDropScheduler(optimizer, quantization_config)
    kd_loss_calculator = KDLossCalculator(original_model)

    best_acc1 = 0
    # optionally resume from a checkpoint
    if resuming_checkpoint is not None and config.to_onnx is None:
        config.start_epoch = resuming_checkpoint['epoch']
        best_acc1 = resuming_checkpoint['best_acc1']
        kd_loss_calculator.original_model.load_state_dict(
            resuming_checkpoint['original_model_state_dict'])
        # Optimizer and scheduler state are restored only when training.
        if 'train' in config.mode:
            optimizer.load_state_dict(resuming_checkpoint['optimizer'])
            optimizer_scheduler.load_state_dict(
                resuming_checkpoint['optimizer_scheduler'])
            logger.info(
                "=> loaded checkpoint '{}' (epoch: {}, best_acc1: {:.3f})".
                format(resuming_checkpoint_path, resuming_checkpoint['epoch'],
                       best_acc1))
        else:
            logger.info(
                "=> loaded checkpoint '{}'".format(resuming_checkpoint_path))

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if config.execution_mode != ExecutionMode.CPU_ONLY:
        cudnn.benchmark = True

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if 'train' in config.mode:
        # batch_multiplier defaults to 1 when "params" or the key is absent.
        batch_multiplier = (quantization_config.get("params", {})).get(
            "batch_multiplier", 1)
        train_staged(config, compression_ctrl, model, criterion,
                     train_criterion_fn, optimizer_scheduler, model_name,
                     optimizer, train_loader, train_sampler, val_loader,
                     kd_loss_calculator, batch_multiplier, best_acc1)

    if 'test' in config.mode:
        validate(val_loader, model, criterion, config)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
Esempio n. 5
0
def train(net, compression_ctrl, train_data_loader, test_data_loader,
          criterion, optimizer, config, lr_scheduler):
    """Train ``net`` epoch by epoch with periodic evaluation and checkpoints.

    Per epoch, exactly once and in this order: the compression scheduler is
    stepped, ``train_epoch`` runs, compression statistics are logged on the
    main process, the network is evaluated every ``config.test_interval``
    epochs (every epoch when it is ``None``), a ``*_last.pth`` checkpoint is
    written on the first rank, and the learning-rate scheduler is stepped.

    A checkpoint is flagged as best when mAP improved at the same
    compression stage, or when the compression stage itself increased.

    :param net: network to train.
    :param compression_ctrl: compression controller (scheduler, statistics,
        compression stage and compression state are used).
    :param train_data_loader: training loader; its length is the epoch size
        and ``dataset.name`` is logged.
    :param test_data_loader: loader passed to ``test_net``.
    :param criterion: loss forwarded to ``train_epoch``.
    :param optimizer: optimizer whose state is checkpointed.
    :param config: sample configuration.
    :param lr_scheduler: stepped with the mAP after each evaluation when it
        is a ``ReduceLROnPlateau``; otherwise stepped with the epoch index
        at the end of every epoch.
    """
    net.train()
    loc_loss = 0
    conf_loss = 0

    epoch_size = len(train_data_loader)
    logger.info('Training {} on {} dataset...'.format(
        config.model, train_data_loader.dataset.name))

    best_mAp = 0
    best_compression_stage = CompressionStage.UNCOMPRESSED
    test_freq_in_epochs = config.test_interval
    if config.test_interval is None:
        test_freq_in_epochs = 1

    max_epochs = config['epochs']

    for epoch in range(config.start_epoch, max_epochs):
        compression_ctrl.scheduler.epoch_step(epoch)

        train_epoch(compression_ctrl, net, config, train_data_loader,
                    criterion, optimizer, epoch_size, epoch, loc_loss,
                    conf_loss)

        if is_main_process():
            logger.info(compression_ctrl.statistics().to_str())

        # The evaluate / checkpoint / lr-step sequence below used to be
        # duplicated verbatim, so each epoch evaluated, checkpointed and
        # stepped both schedulers twice. It now runs once per epoch.
        compression_stage = compression_ctrl.compression_stage()
        is_best = False
        if (epoch + 1) % test_freq_in_epochs == 0:
            with torch.no_grad():
                net.eval()
                mAP = test_net(net,
                               config.device,
                               test_data_loader,
                               distributed=config.multiprocessing_distributed)
                is_best_by_mAP = mAP > best_mAp and compression_stage == best_compression_stage
                is_best = is_best_by_mAP or compression_stage > best_compression_stage
                if is_best:
                    best_mAp = mAP
                best_compression_stage = max(compression_stage,
                                             best_compression_stage)
                if isinstance(lr_scheduler, ReduceLROnPlateau):
                    lr_scheduler.step(mAP)
                net.train()

        if is_on_first_rank(config):
            logger.info('Saving state, epoch: {}'.format(epoch))

            checkpoint_file_path = osp.join(
                config.checkpoint_save_dir,
                "{}_last.pth".format(get_name(config)))
            torch.save(
                {
                    MODEL_STATE_ATTR:
                    net.state_dict(),
                    COMPRESSION_STATE_ATTR:
                    compression_ctrl.get_compression_state(),
                    'optimizer':
                    optimizer.state_dict(),
                    'epoch':
                    epoch,
                }, str(checkpoint_file_path))
            make_additional_checkpoints(checkpoint_file_path,
                                        is_best=is_best,
                                        epoch=epoch + 1,
                                        config=config)

        # Learning rate scheduling should be applied after optimizer's update
        if not isinstance(lr_scheduler, ReduceLROnPlateau):
            lr_scheduler.step(epoch)

    if config.metrics_dump is not None:
        write_metrics(best_mAp, config.metrics_dump)
Esempio n. 6
0
def main_worker(current_gpu, config):
    """Per-process entry point of the object detection sample.

    Sets up the device and logging, builds the ``MultiBoxLoss`` criterion,
    creates the data loaders (unless only export is requested), the model
    with its compression controller and the optimizer, optionally restores
    optimizer state from a resuming checkpoint and then runs the modes
    listed in ``config.mode`` ('train', 'test', 'export'). Training goes
    through the accuracy-aware loop when ``is_accuracy_aware_training``
    says so, otherwise through ``train``.

    :param current_gpu: forwarded to ``configure_device``.
    :param config: sample configuration; read both by attribute and by key.
    """
    #################################
    # Setup experiment environment
    #################################
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    if is_on_first_rank(config):
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)

    config.start_iter = 0
    nncf_config = config.nncf_config
    ##########################
    # Prepare metrics log file
    ##########################

    # Start the metrics dump with a zero value.
    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    ###########################
    # Criterion
    ###########################

    criterion = MultiBoxLoss(config,
                             config['num_classes'],
                             overlap_thresh=0.5,
                             prior_for_matching=True,
                             bkg_label=0,
                             neg_mining=True,
                             neg_pos=3,
                             neg_overlap=0.5,
                             encode_target=False,
                             device=config.device)

    # Loaders stay None when only export is requested.
    train_data_loader = test_data_loader = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    ###########################
    # Prepare data
    ###########################

    pretrained = is_pretrained_model_requested(config)

    # Export-only: 'export' is requested without 'train' or 'test'.
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)
    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        test_data_loader, train_data_loader, init_data_loader = create_dataloaders(
            config)

        # The criterion returns two loss terms; they are summed into one.
        def criterion_fn(model_outputs, target, criterion):
            loss_l, loss_c = criterion(model_outputs, target)
            return loss_l + loss_c

        def autoq_test_fn(model, eval_loader):
            # RL is maximization, change the loss polarity
            return -1 * test_net(model,
                                 config.device,
                                 eval_loader,
                                 distributed=config.distributed,
                                 loss_inference=True,
                                 criterion=criterion)

        # Evaluates the model on the test loader and returns the result of
        # test_net.
        def model_eval_fn(model):
            model.eval()
            mAP = test_net(model,
                           config.device,
                           test_data_loader,
                           distributed=config.distributed,
                           criterion=criterion)
            return mAP

        nncf_config = register_default_init_args(nncf_config,
                                                 init_data_loader,
                                                 criterion=criterion,
                                                 criterion_fn=criterion_fn,
                                                 autoq_eval_fn=autoq_test_fn,
                                                 val_loader=test_data_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=config.device)

    ##################
    # Prepare model
    ##################
    # NOTE(review): the checkpoint path was already read above; this second
    # read looks redundant.
    resuming_checkpoint_path = config.resuming_checkpoint_path

    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(
            resuming_checkpoint_path)
    compression_ctrl, net = create_model(config, resuming_checkpoint)
    if config.distributed:
        # Split batch size and worker count across the GPUs of the node.
        config.batch_size //= config.ngpus_per_node
        config.workers //= config.ngpus_per_node
        compression_ctrl.distributed()

    ###########################
    # Optimizer
    ###########################

    params_to_optimize = get_parameter_groups(net, config)
    optimizer, lr_scheduler = make_optimizer(params_to_optimize, config)

    #################################
    # Load additional checkpoint data
    #################################

    if resuming_checkpoint_path is not None and 'train' in config.mode:
        # Fall back to the fresh optimizer state when the checkpoint has none.
        optimizer.load_state_dict(
            resuming_checkpoint.get('optimizer', optimizer.state_dict()))
        # Presumably the checkpoint stores the last finished epoch, hence
        # the +1 - verify against the code that writes the checkpoint.
        config.start_epoch = resuming_checkpoint.get('epoch', 0) + 1

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if 'train' in config.mode and is_accuracy_aware_training(config):
        # validation function that returns the target metric value
        # (the epoch argument is accepted but not used here)
        # pylint: disable=E1123
        def validate_fn(model, epoch):
            model.eval()
            mAP = test_net(model,
                           config.device,
                           test_data_loader,
                           distributed=config.distributed)
            model.train()
            return mAP

        # training function that trains the model for one epoch (full training dataset pass)
        # it is assumed that all the NNCF-related methods are properly called inside of
        # this function (like e.g. the step and epoch_step methods of the compression scheduler)
        def train_epoch_fn(compression_ctrl, model, epoch, optimizer,
                           **kwargs):
            loc_loss = 0
            conf_loss = 0
            epoch_size = len(train_data_loader)
            train_epoch(compression_ctrl, model, config, train_data_loader,
                        criterion, optimizer, epoch_size, epoch, loc_loss,
                        conf_loss)

        # function that initializes optimizers & lr schedulers to start training
        def configure_optimizers_fn():
            params_to_optimize = get_parameter_groups(net, config)
            optimizer, lr_scheduler = make_optimizer(params_to_optimize,
                                                     config)
            return optimizer, lr_scheduler

        acc_aware_training_loop = create_accuracy_aware_training_loop(
            nncf_config, compression_ctrl)
        net = acc_aware_training_loop.run(
            net,
            train_epoch_fn=train_epoch_fn,
            validate_fn=validate_fn,
            configure_optimizers_fn=configure_optimizers_fn,
            tensorboard_writer=config.tb,
            log_dir=config.log_dir)
    elif 'train' in config.mode:
        train(net, compression_ctrl, train_data_loader, test_data_loader,
              criterion, optimizer, config, lr_scheduler)

    if 'test' in config.mode:
        with torch.no_grad():
            net.eval()
            # Either report the loss or compute mAP, depending on
            # ssd_params.loss_inference (False when absent).
            if config['ssd_params'].get('loss_inference', False):
                model_loss = test_net(net,
                                      config.device,
                                      test_data_loader,
                                      distributed=config.distributed,
                                      loss_inference=True,
                                      criterion=criterion)
                logger.info("Final model loss: {:.3f}".format(model_loss))
            else:
                mAp = test_net(net,
                               config.device,
                               test_data_loader,
                               distributed=config.distributed)
                if config.metrics_dump is not None:
                    write_metrics(mAp, config.metrics_dump)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
Esempio n. 7
0
 def _is_enabled(self):
     """Tell whether this object is active: suitable mode and main process."""
     suitable = self.is_suitable_mode
     if not suitable:
         # Mirror ``and`` semantics: hand back the falsy value itself.
         return suitable
     return is_main_process()
Esempio n. 8
0
def main_worker(current_gpu, config):
    """Per-process entry point of the semantic segmentation sample.

    Sets up the device and logging, resolves the dataset and its color
    encoding, builds loaders and criterion (unless only export is
    requested), creates the model with its compression controller and then
    runs the modes listed in ``config.mode`` ('train', 'test', 'export').
    Training goes through the accuracy-aware loop when
    ``is_accuracy_aware_training`` says so, otherwise through ``train``.

    :param current_gpu: forwarded to ``configure_device``.
    :param config: sample configuration; read both by attribute and via
        ``get``.
    """
    configure_device(current_gpu, config)
    config.mlflow = SafeMLFLow(config)
    if is_main_process():
        configure_logging(logger, config)
        print_args(config)

    set_seed(config)
    logger.info(config)

    dataset = get_dataset(config.dataset)
    color_encoding = dataset.color_encoding
    # One class per entry of the color encoding.
    num_classes = len(color_encoding)

    # Start the metrics dump with a zero value.
    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    # Loaders and criterion stay None when only export is requested.
    train_loader = val_loader = criterion = None
    resuming_checkpoint_path = config.resuming_checkpoint_path

    nncf_config = config.nncf_config

    pretrained = is_pretrained_model_requested(config)

    # Post-processes outputs and targets for the configured model before
    # applying the criterion.
    def criterion_fn(model_outputs, target, criterion_):
        labels, loss_outputs, _ = \
            loss_funcs.do_model_specific_postprocessing(config.model, target, model_outputs)
        return criterion_(loss_outputs, labels)

    # Export-only: 'export' is requested without 'train' or 'test'.
    is_export_only = 'export' in config.mode and (
        'train' not in config.mode and 'test' not in config.mode)
    if is_export_only:
        assert pretrained or (resuming_checkpoint_path is not None)
    else:
        loaders, w_class = load_dataset(dataset, config)
        train_loader, val_loader, init_loader = loaders
        criterion = get_criterion(w_class, config)

        def autoq_test_fn(model, eval_loader):
            return test(model, eval_loader, criterion, color_encoding, config)

        # Same evaluation, with the loader fixed to the validation loader.
        model_eval_fn = functools.partial(autoq_test_fn,
                                          eval_loader=val_loader)

        nncf_config = register_default_init_args(nncf_config,
                                                 init_loader,
                                                 criterion=criterion,
                                                 criterion_fn=criterion_fn,
                                                 autoq_eval_fn=autoq_test_fn,
                                                 val_loader=val_loader,
                                                 model_eval_fn=model_eval_fn,
                                                 device=config.device)

    model = load_model(config.model,
                       pretrained=pretrained,
                       num_classes=num_classes,
                       model_params=config.get('model_params', {}),
                       weights_path=config.get('weights'))

    model.to(config.device)

    resuming_checkpoint = None
    if resuming_checkpoint_path is not None:
        resuming_checkpoint = load_resuming_checkpoint(
            resuming_checkpoint_path)
    model_state_dict, compression_state = extract_model_and_compression_states(
        resuming_checkpoint)
    compression_ctrl, model = create_compressed_model(model, nncf_config,
                                                      compression_state)
    if model_state_dict is not None:
        load_state(model, model_state_dict, is_resume=True)
    model, model_without_dp = prepare_model_for_execution(model, config)

    if config.distributed:
        compression_ctrl.distributed()

    log_common_mlflow_params(config)

    if is_export_only:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
        return

    if is_main_process():
        statistics = compression_ctrl.statistics()
        logger.info(statistics.to_str())

    if is_accuracy_aware_training(config) and 'train' in config.mode:

        # validation function; the epoch argument is accepted but not used
        def validate_fn(model, epoch):
            return test(model, val_loader, criterion, color_encoding, config)

        # training function that trains the model for one epoch (full training dataset pass)
        # it is assumed that all the NNCF-related methods are properly called inside of
        # this function (like e.g. the step and epoch_step methods of the compression scheduler)
        # A new IoU metric and Train object are created on every call.
        def train_epoch_fn(compression_ctrl, model, optimizer, **kwargs):
            ignore_index = None
            ignore_unlabeled = config.get("ignore_unlabeled", True)
            if ignore_unlabeled and ('unlabeled' in color_encoding):
                ignore_index = list(color_encoding).index('unlabeled')
            metric = IoU(len(color_encoding), ignore_index=ignore_index)
            train_obj = Train(model, train_loader, optimizer, criterion,
                              compression_ctrl, metric, config.device,
                              config.model)
            train_obj.run_epoch(config.print_step)

        # function that initializes optimizers & lr schedulers to start training
        def configure_optimizers_fn():
            optim_config = config.get('optimizer', {})
            optim_params = optim_config.get('optimizer_params', {})
            lr = optim_params.get("lr", 1e-4)
            params_to_optimize = get_params_to_optimize(
                model_without_dp, lr * 10, config)
            optimizer, lr_scheduler = make_optimizer(params_to_optimize,
                                                     config)
            return optimizer, lr_scheduler

        # NOTE(review): `config` is passed here, not the `nncf_config` that
        # received the default init args above - confirm this is intended.
        acc_aware_training_loop = create_accuracy_aware_training_loop(
            config, compression_ctrl)
        model = acc_aware_training_loop.run(
            model,
            train_epoch_fn=train_epoch_fn,
            validate_fn=validate_fn,
            configure_optimizers_fn=configure_optimizers_fn,
            tensorboard_writer=config.tb,
            log_dir=config.log_dir)

    elif 'train' in config.mode:
        train(model, model_without_dp, compression_ctrl, train_loader,
              val_loader, criterion, color_encoding, config,
              resuming_checkpoint)

    if 'test' in config.mode:
        logger.info(model)
        # Count the elements of all parameters that require gradients.
        model_parameters = filter(lambda p: p.requires_grad,
                                  model.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        logger.info("Trainable argument count:{params}".format(params=params))
        model = model.to(config.device)
        test(model, val_loader, criterion, color_encoding, config)

    if 'export' in config.mode:
        compression_ctrl.export_model(config.to_onnx)
        logger.info("Saved to {}".format(config.to_onnx))
Esempio n. 9
0
def train(model, model_without_dp, compression_ctrl, train_loader, val_loader,
          criterion, class_encoding, config, resuming_checkpoint):
    """Train ``model`` for the configured epochs with periodic validation.

    An optimizer and LR scheduler are built from ``config``. When
    ``resuming_checkpoint`` is given, the optimizer state, the start epoch and
    the best mean IoU are restored from it. Every epoch steps the compression
    scheduler, runs one training pass and, on the main process, writes
    scalars to ``config.tb``. Every ``config.save_freq`` epochs and on the
    last epoch a validation pass is run, the best mean IoU / compression
    stage are updated, and the main process saves a checkpoint.

    Args:
        model: network handed to the ``Train`` / ``Test`` runners.
        model_without_dp: model handed to ``get_params_to_optimize``.
        compression_ctrl: compression controller; its ``scheduler``,
            ``loss()``, ``statistics()`` and ``compression_stage()`` are used.
        train_loader: loader consumed by the training runner.
        val_loader: loader consumed by the validation runner.
        criterion: loss passed to both runners.
        class_encoding: mapping whose keys are the class names.
        config: run configuration (``start_epoch``, ``epochs``, ``save_freq``,
            ``print_step``, ``device``, ``model``, ``distributed``, ``tb``,
            ``dataset``, ``metrics_dump`` are read here).
        resuming_checkpoint: ``None`` or a mapping with ``'optimizer'``,
            ``'epoch'`` and ``'miou'`` entries.

    Returns:
        The ``model`` object that was passed in.
    """
    logger.info("\nTraining...\n")

    # Dump the architecture so it can be checked in the log.
    logger.info(model)

    base_lr = config.get('optimizer', {}).get('optimizer_params',
                                              {}).get("lr", 1e-4)
    param_groups = get_params_to_optimize(model_without_dp, base_lr * 10,
                                          config)
    optimizer, lr_scheduler = make_optimizer(param_groups, config)

    # Evaluation metric: optionally leave the 'unlabeled' class out of it.
    skip_unlabeled = config.get("ignore_unlabeled", True)
    if skip_unlabeled and ('unlabeled' in class_encoding):
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None
    metric = IoU(len(class_encoding), ignore_index=ignore_index)

    best_miou = -1
    best_compression_stage = CompressionStage.UNCOMPRESSED

    # Optionally pick up where a previous run stopped.
    if resuming_checkpoint is not None:
        if optimizer is not None:
            optimizer.load_state_dict(resuming_checkpoint['optimizer'])
        resumed_epoch = resuming_checkpoint['epoch']
        best_miou = resuming_checkpoint['miou']

        logger.info("Resuming from model: Start epoch = {0} "
                    "| Best mean IoU = {1:.4f}".format(resumed_epoch,
                                                       best_miou))
        config.start_epoch = resumed_epoch

    trainer = Train(model, train_loader, optimizer, criterion,
                    compression_ctrl, metric, config.device, config.model)
    validator = Test(model, val_loader, criterion, metric, config.device,
                     config.model)

    # ReduceLROnPlateau is stepped after validation with a metric value;
    # every other scheduler is stepped once per epoch with the epoch index.
    plateau_scheduler = isinstance(lr_scheduler, ReduceLROnPlateau)

    for epoch in range(config.start_epoch, config.epochs):
        compression_ctrl.scheduler.epoch_step()
        logger.info(">>>> [Epoch: {0:d}] Training".format(epoch))

        if config.distributed:
            # Tell the distributed sampler which epoch is starting.
            train_loader.sampler.set_epoch(epoch)

        train_loss, (_, train_miou) = trainer.run_epoch(config.print_step)
        if not plateau_scheduler:
            # Learning rate scheduling is applied after the optimizer update.
            lr_scheduler.step(epoch)

        logger.info(
            ">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
            format(epoch, train_loss, train_miou))

        if is_main_process():
            writer = config.tb
            writer.add_scalar("train/loss", train_loss, epoch)
            writer.add_scalar("train/mIoU", train_miou, epoch)
            writer.add_scalar("train/learning_rate",
                              optimizer.param_groups[0]['lr'], epoch)
            writer.add_scalar("train/compression_loss",
                              compression_ctrl.loss(), epoch)

            quick_stats = compression_ctrl.statistics(
                quickly_collected_only=True)
            for stat_name, stat_value in prepare_for_tensorboard(
                    quick_stats).items():
                writer.add_scalar(
                    "compression/statistics/{0}".format(stat_name),
                    stat_value, epoch)

        # Validation, bookkeeping and checkpointing only happen every
        # `save_freq` epochs and on the final epoch.
        is_last_epoch = epoch + 1 == config.epochs
        if (epoch + 1) % config.save_freq != 0 and not is_last_epoch:
            continue

        logger.info(">>>> [Epoch: {0:d}] Validation".format(epoch))

        val_loss, (iou, val_miou) = validator.run_epoch(config.print_step)

        logger.info(
            ">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
            format(epoch, val_loss, val_miou))

        if is_main_process():
            config.tb.add_scalar("val/mIoU", val_miou, epoch)
            config.tb.add_scalar("val/loss", val_loss, epoch)
            for cls_idx, (cls_name, cls_iou) in enumerate(
                    zip(class_encoding.keys(), iou)):
                config.tb.add_scalar(
                    "{}/mIoU_Cls{}_{}".format(config.dataset, cls_idx,
                                              cls_name), cls_iou, epoch)

        # A result is "best" if it improves mIoU within the current best
        # compression stage, or if a higher compression stage is reached.
        current_stage = compression_ctrl.compression_stage()
        improved_miou = (val_miou > best_miou
                         and current_stage == best_compression_stage)
        is_best = improved_miou or current_stage > best_compression_stage
        if is_best:
            best_miou = val_miou
        best_compression_stage = max(current_stage, best_compression_stage)

        if config.metrics_dump is not None:
            write_metrics(best_miou, config.metrics_dump)

        if plateau_scheduler:
            # Learning rate scheduling is applied after the optimizer update.
            lr_scheduler.step(best_miou)

        # Print per-class IoU on the last epoch or for a new best result.
        if is_last_epoch or is_best:
            for cls_name, cls_iou in zip(class_encoding.keys(), iou):
                logger.info("{0}: {1:.4f}".format(cls_name, cls_iou))

        # Checkpoint after every validation; `is_best` is forwarded to
        # make_additional_checkpoints together with the checkpoint path.
        if is_main_process():
            checkpoint_path = save_checkpoint(model, compression_ctrl,
                                              optimizer, epoch, best_miou,
                                              config)
            make_additional_checkpoints(checkpoint_path, is_best, epoch,
                                        config)
            full_stats = compression_ctrl.statistics()
            logger.info(full_stats.to_str())

    return model