Example #1
def validate(epoch, calib_loader, val_loader, criterion, val_meters,
             model_wrapper, ema, phase):
    """Calibrate and validate."""
    assert phase in ['test', 'val']
    model_eval_wrapper = mc.get_ema_model(ema, model_wrapper)

    # bn_calibration
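    # Re-estimate BatchNorm running statistics on the calibration split
    # before evaluating, since the stored statistics may no longer match
    # the (possibly pruned) architecture being tested.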
    if FLAGS.get('bn_calibration', False):
        if not FLAGS.use_distributed:
            logging.warning(
                'Only GPU0 is used for calibration when using DataParallel')
        with torch.no_grad():
            _ = run_one_epoch(epoch,
                              calib_loader,
                              model_eval_wrapper,
                              criterion,
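                              # training-only arguments, unused in this phase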
                              None,
                              None,
                              None,
                              None,
                              val_meters,
                              max_iter=FLAGS.bn_calibration_steps,
                              phase='bn_calibration')
        if FLAGS.use_distributed:
            udist.allreduce_bn(model_eval_wrapper)

    # val
    with torch.no_grad():
        results = run_one_epoch(epoch,
                                val_loader,
                                model_eval_wrapper,
                                criterion,
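                                # training-only arguments, unused in this phase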
                                None,
                                None,
                                None,
                                None,
                                val_meters,
                                phase=phase)
    summary_bn(model_eval_wrapper, phase)
    return results, model_eval_wrapper
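A minimal call-site sketch for Example #1 (hypothetical: the loaders, criterion, meters, model_wrapper, and ema are assumed to be constructed earlier in train.py, as in the HR-NAS codebase):

# Hypothetical end-of-epoch validation call; all inputs are assumed to be
# built as in the surrounding train.py.
results, eval_wrapper = validate(epoch, calib_loader, val_loader, criterion,
                                 val_meters, model_wrapper, ema, phase='val')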
Example #2
File: train.py Project: dingmyu/HR-NAS
def validate(epoch,
             calib_loader,
             val_loader,
             criterion,
             val_meters,
             model_wrapper,
             ema,
             phase,
             segval=None,
             val_set=None):
    """Calibrate and validate."""
    assert phase in ['test', 'val']
    model_eval_wrapper = mc.get_ema_model(ema, model_wrapper)

    # bn_calibration
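    # Unlike Example #1, calibration here is skipped entirely unless a
    # pruning method is configured.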
    if FLAGS.prune_params['method'] is not None:
        if FLAGS.get('bn_calibration', False):
            if not FLAGS.use_distributed:
                logging.warning(
                    'Only GPU0 is used for calibration when using DataParallel')
            with torch.no_grad():
                _ = run_one_epoch(epoch,
                                  calib_loader,
                                  model_eval_wrapper,
                                  criterion,
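                                  # training-only arguments, unused here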
                                  None,
                                  None,
                                  None,
                                  None,
                                  val_meters,
                                  max_iter=FLAGS.bn_calibration_steps,
                                  phase='bn_calibration')
            if FLAGS.use_distributed:
                udist.allreduce_bn(model_eval_wrapper)

    # val
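    # Task dispatch: COCO keypoint evaluation runs on the master process
    # only; other segmentation datasets go through the segval runner; all
    # remaining tasks fall back to run_one_epoch.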
    with torch.no_grad():
        if FLAGS.model_kwparams.task == 'segmentation':
            if FLAGS.dataset == 'coco':
                results = 0
                if udist.is_master():
                    results = keypoint_val(val_set, val_loader,
                                           model_eval_wrapper.module,
                                           criterion)
            else:
                assert segval is not None
                test_model = (model_eval_wrapper.module
                              if FLAGS.single_gpu_test else model_eval_wrapper)
                results = segval.run(epoch, val_loader, test_model, FLAGS)
        else:
            results = run_one_epoch(epoch,
                                    val_loader,
                                    model_eval_wrapper,
                                    criterion,
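                                    # training-only arguments, unused here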
                                    None,
                                    None,
                                    None,
                                    None,
                                    val_meters,
                                    phase=phase)
    summary_bn(model_eval_wrapper, phase)
    return results, model_eval_wrapper
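Example #2 extends Example #1 in two ways: BN calibration only runs when FLAGS.prune_params['method'] is set, and for the segmentation task validation is routed to task-specific evaluators (keypoint_val for COCO, segval.run otherwise) instead of run_one_epoch. A hypothetical call for a segmentation run follows; segval and val_set are assumed to be created elsewhere in train.py:

# Hypothetical call for a segmentation task; segval and val_set are
# assumed to exist as in the surrounding train.py.
results, eval_wrapper = validate(epoch, calib_loader, val_loader, criterion,
                                 val_meters, model_wrapper, ema, phase='val',
                                 segval=segval, val_set=val_set)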