Example #1
File: test_zoo.py  Project: PIlotcnc/new
def test_load_model_from_stub(stub, model_args, other_args):
    # stub, model_args, and other_args are supplied by pytest parametrization
    model = Zoo.load_model_from_stub(stub, **other_args)
    model.download(overwrite=True)
    for key in model_args:
        if key and hasattr(model, key):
            assert getattr(model, key) == model_args[key]
    shutil.rmtree(model.dir_path)
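For reference, a minimal standalone sketch of the flow this test exercises, assuming the sparsezoo package is installed; the stub value is an illustrative placeholder, not a guaranteed SparseZoo entry:

import shutil

from sparsezoo import Zoo

# placeholder stub for illustration only
stub = "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-moderate"
model = Zoo.load_model_from_stub(stub)
model.download(overwrite=True)  # fetch the model files to the local cache
print(model.dir_path)           # directory the files were downloaded to
shutil.rmtree(model.dir_path)   # clean up, mirroring the test's teardown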
Example #2
def load_data(data_path: str) -> List[List[numpy.ndarray]]:
    """
    Loads data from a given SparseZoo stub or a directory of .npz files

    :param data_path: directory path to .npz files to load or SparseZoo stub
    :return: list of loaded npz files
    """

    if data_path.startswith("zoo:"):
        data_dir = Zoo.load_model_from_stub(
            data_path).data_inputs.downloaded_path()
    else:
        data_dir = data_path
        data_files = os.listdir(data_dir)
        if any(".npz" not in file_name for file_name in data_files):
            raise RuntimeError(
                f"All files in data directory {data_dir} must have a .npz extension "
                f"found {[name for name in data_files if '.npz' not in name]}")

    samples = load_numpy_list(data_dir)
    # unwrap unloaded numpy files
    samples = [
        load_numpy(sample) if isinstance(sample, str) else sample
        for sample in samples
    ]

    processed_samples = []
    for sample in samples:
        # each sample is a dict of input name -> array; keep just the arrays
        processed_samples.append(list(sample.values()))

    return processed_samples
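A hedged usage sketch, assuming load_data is importable and that ./samples (a placeholder path) holds .npz files produced by numpy.savez:

# each returned sample is a list of arrays, one per entry in its .npz archive
batches = load_data("./samples")
print(f"loaded {len(batches)} samples; first sample holds {len(batches[0])} arrays")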
Example #3
def model_to_path(model: Union[str, Model, File]) -> str:
    """
    Deals with the various forms a model can take. Either an ONNX file,
    a SparseZoo model stub prefixed by 'zoo:', a SparseZoo Model object,
    or a SparseZoo ONNX File object that defines the neural network
    """
    if not model:
        raise ValueError(
            "model must be a path, sparsezoo.Model, or sparsezoo.File")

    if isinstance(model, str) and model.startswith("zoo:"):
        # load SparseZoo Model from stub
        if sparsezoo_import_error is not None:
            raise sparsezoo_import_error
        model = Zoo.load_model_from_stub(model)

    if Model is not object and isinstance(model, Model):
        # Model/File fall back to `object` when the sparsezoo import fails,
        # so guard the isinstance checks; default to the model's main ONNX file
        model = model.onnx_file.downloaded_path()
    elif File is not object and isinstance(model, File):
        # get the downloaded_path -- will auto download if not on local system
        model = model.downloaded_path()

    if not isinstance(model, str):
        raise ValueError("unsupported type for model: {}".format(type(model)))

    if not os.path.exists(model):
        raise ValueError("model path must exist: given {}".format(model))

    return model
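A short sketch of the three input forms the helper accepts; the paths and stub are placeholders:

# 1. a local ONNX file path -- returned as-is after an existence check
path = model_to_path("model.onnx")

# 2. a SparseZoo stub string -- resolved and downloaded first
path = model_to_path("zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-moderate")

# 3. a sparsezoo.Model or sparsezoo.File -- reduced to its downloaded_path()
# path = model_to_path(some_sparsezoo_model)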
Example #4
def modify_yolo_onnx_input_shape(
    model_path: str, image_shape: Tuple[int, int]
) -> Tuple[str, Optional[NamedTemporaryFile]]:
    """
    Creates a new YOLOv3 ONNX model from the given path that accepts the given input
    shape. If the given model already has the given input shape no modifications are
    made. Uses a tempfile to store the modified model file.

    :param model_path: file path to YOLOv3 ONNX model or SparseZoo stub of the model
        to be loaded
    :param image_shape: 2-tuple of the image shape to resize this yolo model to
    :return: filepath to an ONNX model reshaped to the given input shape, or the
        original path if the shape is unchanged. Additionally returns the
        NamedTemporaryFile for managing the scope of the object for file deletion
    """
    original_model_path = model_path
    if model_path.startswith("zoo:"):
        # load SparseZoo Model from stub
        model = Zoo.load_model_from_stub(model_path)
        model_path = model.onnx_file.downloaded_path()
        print(f"Downloaded {original_model_path} to {model_path}")

    model = onnx.load(model_path)
    model_input = model.graph.input[0]

    initial_x = get_tensor_dim_shape(model_input, 2)
    initial_y = get_tensor_dim_shape(model_input, 3)

    if not (isinstance(initial_x, int) and isinstance(initial_y, int)):
        return model_path, None  # model graph does not have static integer input shape

    if (initial_x, initial_y) == tuple(image_shape):
        return model_path, None  # no shape modification needed

    scale_x = initial_x / image_shape[0]
    scale_y = initial_y / image_shape[1]
    set_tensor_dim_shape(model_input, 2, image_shape[0])
    set_tensor_dim_shape(model_input, 3, image_shape[1])

    for model_output in model.graph.output:
        output_x = get_tensor_dim_shape(model_output, 2)
        output_y = get_tensor_dim_shape(model_output, 3)
        set_tensor_dim_shape(model_output, 2, int(output_x / scale_x))
        set_tensor_dim_shape(model_output, 3, int(output_y / scale_y))

    tmp_file = NamedTemporaryFile()  # file will be deleted after program exit
    onnx.save(model, tmp_file.name)

    print(
        f"Overwriting original model shape {(initial_x, initial_y)} to {image_shape}\n"
        f"Original model path: {original_model_path}, new temporary model saved to "
        f"{tmp_file.name}"
    )

    return tmp_file.name, tmp_file
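A hedged usage sketch: keep the returned NamedTemporaryFile in scope, since the modified model is deleted once that object is garbage-collected (the path and shape are placeholders):

model_path, tmp_file = modify_yolo_onnx_input_shape("yolo_v3.onnx", (416, 416))
# ... run inference or benchmarking against model_path here ...
if tmp_file is not None:
    tmp_file.close()  # explicitly remove the temporary model when done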
Example #5
def _load_model(args) -> Tuple[Any, List[str]]:
    if args.engine == ORT_ENGINE and ort_error is not None:
        raise ort_error

    # validation
    if (args.num_cores is not None and args.engine == ORT_ENGINE
            and onnxruntime.__version__ < "1.7"):
        raise ValueError(
            "overriding default num_cores not supported for onnxruntime < 1.7.0. "
            "If using an older build with OpenMP, try setting the OMP_NUM_THREADS "
            "environment variable")

    # load model from sparsezoo if necessary
    if args.model_filepath.startswith("zoo:"):
        zoo_model = Zoo.load_model_from_stub(args.model_filepath)
        downloaded_path = zoo_model.onnx_file.downloaded_path()
        print(
            f"downloaded sparsezoo model {args.model_filepath} to {downloaded_path}"
        )
        args.model_filepath = downloaded_path

    # scale static ONNX graph to desired image shape
    input_names = []
    if args.engine in [DEEPSPARSE_ENGINE, ORT_ENGINE]:
        args.model_filepath, input_names, _ = overwrite_transformer_onnx_model_inputs(
            args.model_filepath,
            batch_size=args.batch_size,
            max_length=args.max_sequence_length,
        )

    # load model
    if args.engine == DEEPSPARSE_ENGINE:
        print(f"Compiling deepsparse model for {args.model_filepath}")
        model = compile_model(args.model_filepath, args.batch_size,
                              args.num_cores)
        print(f"Engine info: {model}")
    elif args.engine == ORT_ENGINE:
        print(f"loading onnxruntime model for {args.model_filepath}")

        sess_options = onnxruntime.SessionOptions()
        if args.num_cores is not None:
            sess_options.intra_op_num_threads = args.num_cores
        sess_options.log_severity_level = 3
        sess_options.graph_optimization_level = (
            onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL)
        model = onnxruntime.InferenceSession(args.model_filepath,
                                             sess_options=sess_options)
    else:
        raise ValueError(f"unsupported engine: {args.engine}")

    return model, input_names
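A minimal sketch of driving this helper without the full CLI, using a namespace that mimics the parsed arguments; the attribute names and engine constant come from the snippet above, and the model path is a placeholder:

from types import SimpleNamespace

args = SimpleNamespace(
    engine=DEEPSPARSE_ENGINE,     # or ORT_ENGINE
    model_filepath="model.onnx",  # local ONNX path or "zoo:" stub (placeholder)
    batch_size=1,
    num_cores=None,
    max_sequence_length=128,
)
model, input_names = _load_model(args)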
Example #6
File: test_zoo.py  Project: PIlotcnc/new
def test_search_sparse_recipes_from_stub(model_stub, other_args):
    model = Zoo.load_model_from_stub(model_stub, **other_args)
    recipes = Zoo.search_sparse_recipes(model_stub)
    assert len(recipes) > 0

    for recipe in recipes:
        assert recipe
        assert recipe.model_metadata.domain == model.domain
        assert recipe.model_metadata.sub_domain == model.sub_domain
        assert recipe.model_metadata.architecture == model.architecture
        assert recipe.model_metadata.sub_architecture == model.sub_architecture
        assert recipe.model_metadata.framework == model.framework
        assert recipe.model_metadata.repo == model.repo
        assert recipe.model_metadata.dataset == model.dataset
        assert recipe.model_metadata.training_scheme == model.training_scheme
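An illustrative sketch of the API under test; the stub is a placeholder:

recipes = Zoo.search_sparse_recipes(
    "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/base-none")
for recipe in recipes:
    # every recipe carries the metadata of the model it sparsifies
    print(recipe.model_metadata.architecture, recipe.model_metadata.dataset)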
Example #7
def get_onnx_path_and_configs(
    model_path: str,
) -> Tuple[str, Optional[str], Optional[str]]:
    """
    :param model_path: path to onnx file, transformers sparsezoo stub,
        or directory containing `model.onnx`, `config.json`, and/or
        `tokenizer.json` files. If no `model.onnx` file is found in
        a model directory, an exception will be raised
    :return: tuple of ONNX file path, parent directory of config file
        if it exists, and parent directory of tokenizer config file if it
        exists. (Parent directories returned instead of absolute path
        for compatibility with transformers .from_pretrained() method)
    """
    if os.path.isfile(model_path):
        return model_path, None, None

    config_path = None
    tokenizer_path = None
    if os.path.isdir(model_path):
        model_files = os.listdir(model_path)

        if _MODEL_DIR_ONNX_NAME not in model_files:
            raise ValueError(
                f"{_MODEL_DIR_ONNX_NAME} not found in transformers model directory "
                f"{model_path}. Be sure that an export of the model is written to "
                f"{os.path.join(model_path, _MODEL_DIR_ONNX_NAME)}")
        onnx_path = os.path.join(model_path, _MODEL_DIR_ONNX_NAME)

        if _MODEL_DIR_CONFIG_NAME in model_files:
            config_path = model_path
        if _MODEL_DIR_TOKENIZER_NAME in model_files:
            tokenizer_path = model_path

    elif model_path.startswith("zoo:"):
        zoo_model = Zoo.load_model_from_stub(model_path)
        onnx_path = zoo_model.onnx_file.downloaded_path()

        for framework_file in zoo_model.framework_files:
            if framework_file.display_name == _MODEL_DIR_CONFIG_NAME:
                config_path = _get_file_parent(
                    framework_file.downloaded_path())
            if "tokenizer" in framework_file.display_name:
                tokenizer_path = _get_file_parent(
                    framework_file.downloaded_path())
    else:
        raise ValueError(
            f"model_path {model_path} is not a valid file, directory, or zoo stub"
        )
    return onnx_path, config_path, tokenizer_path
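A hedged sketch of feeding the returned directories to transformers' .from_pretrained(), the compatibility the docstring calls out; the export directory is a placeholder:

from transformers import AutoConfig, AutoTokenizer

onnx_path, config_dir, tokenizer_dir = get_onnx_path_and_configs("./bert_export")
if config_dir is not None:
    config = AutoConfig.from_pretrained(config_dir)
if tokenizer_dir is not None:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)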
Example #8
def fix_onnx_input_shape(
    model_path: str,
    image_shape: Optional[Tuple[int, int]],
) -> Tuple[str, Optional[NamedTemporaryFile], Tuple[int, int]]:
    """
    Creates a new ONNX model from the given path that accepts the given input
    shape. If the given model already has the given input shape no modifications are
    made. Uses a tempfile to store the modified model file.

    :param model_path: file path to ONNX model or SparseZoo stub of the model
        to be loaded
    :param image_shape: 2-tuple of the image shape to resize this model to, or None if
        no resizing needed
    :return: filepath to an ONNX model reshaped to the given input shape, or the
        original path if the shape is unchanged. Additionally returns the
        NamedTemporaryFile for managing the scope of the object for file deletion,
        and the image shape to benchmark the new model with.
    """
    original_model_path = model_path
    if model_path.startswith("zoo:"):
        # load SparseZoo Model from stub
        model = Zoo.load_model_from_stub(model_path)
        model_path = model.onnx_file.downloaded_path()
        print(f"Downloaded {original_model_path} to {model_path}")

    model = onnx.load(model_path)
    model_input = model.graph.input[0]

    original_x = get_tensor_dim_shape(model_input, 2)
    original_y = get_tensor_dim_shape(model_input, 3)
    original_image_shape = (original_x, original_y)

    if image_shape is None or original_image_shape == tuple(image_shape):
        return model_path, None, original_image_shape  # no shape modification needed

    set_tensor_dim_shape(model_input, 2, image_shape[0])
    set_tensor_dim_shape(model_input, 3, image_shape[1])

    tmp_file = NamedTemporaryFile()  # file will be deleted after program exit
    onnx.save(model, tmp_file.name)

    print(
        f"Overwriting original model shape {original_image_shape} to {image_shape}\n"
        f"Original model path: {original_model_path}, new temporary model saved to "
        f"{tmp_file.name}")

    return tmp_file.name, tmp_file, image_shape
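Usage mirrors example #4, with the extra third return value; a short sketch with placeholder arguments:

model_path, tmp_file, run_shape = fix_onnx_input_shape("model.onnx", (320, 320))
print(f"benchmarking {model_path} at image shape {run_shape}")
# pass image_shape=None to keep the model's original input shape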
Example #9
def _load_data(args, input_names) -> List[List[numpy.ndarray]]:
    if args.data_path.startswith("zoo:"):
        data_dir = Zoo.load_model_from_stub(
            args.data_path).data_inputs.downloaded_path()
    else:
        data_dir = args.data_path
        data_files = os.listdir(data_dir)
        if any(".npz" not in file_name for file_name in data_files):
            raise RuntimeError(
                f"All files in data directory {data_dir} must have a .npz extension "
                f"found {[name for name in data_files if '.npz' not in name]}")

    samples = load_numpy_list(data_dir)

    # unwrap unloaded numpy files
    samples = [
        load_numpy(sample) if isinstance(sample, str) else sample
        for sample in samples
    ]

    processed_samples = []
    warning_given = False
    for sample in samples:
        if not all(inp_name in sample for inp_name in
                   input_names) or len(input_names) != len(sample):
            if not warning_given:
                warnings.warn(
                    "input sample found whose input names do not match the model input "
                    "names, this may cause an exception during benchmarking")
                warning_given = True
            sample = list(sample.values())
        else:
            sample = [sample[inp_name] for inp_name in input_names]

        for idx, array in enumerate(sample):
            processed_array = numpy.zeros(
                [args.max_sequence_length, *array.shape[1:]],
                dtype=array.dtype,
            )
            if array.shape[0] < args.max_sequence_length:
                processed_array[:array.shape[0], ...] = array
            else:
                processed_array[:, ...] = array[:args.max_sequence_length, ...]
            sample[idx] = processed_array
        processed_samples.append(sample)
    return processed_samples
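The inner loop pads or truncates every array to max_sequence_length along axis 0; a tiny self-contained sketch of that behavior:

import numpy

max_sequence_length = 8
array = numpy.ones((5, 3), dtype=numpy.int64)
processed = numpy.zeros([max_sequence_length, *array.shape[1:]], dtype=array.dtype)
processed[: array.shape[0], ...] = array  # shorter inputs are zero-padded at the end
print(processed.shape)                    # (8, 3)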
Example #10
File: train.py  Project: joskid/sparseml
def main():
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
        _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                     % (args.rank, args.world_size))
    else:
        _logger.info('Training with a single process on 1 GPU.')
    assert args.rank >= 0

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    if args.amp:
        # for backwards compat, `--amp` arg tries apex before native amp
        if has_apex:
            args.apex_amp = True
        elif has_native_amp:
            args.native_amp = True
    if args.apex_amp and has_apex:
        use_amp = 'apex'
    elif args.native_amp and has_native_amp:
        use_amp = 'native'
    elif args.apex_amp or args.native_amp:
        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                        "Install NVIDA apex or upgrade to PyTorch 1.6")

    torch.manual_seed(args.seed + args.rank)

    ####################################################################################
    # Start - SparseML optional load weights from SparseZoo
    ####################################################################################
    if args.initial_checkpoint == "zoo":
        # Load checkpoint from base weights associated with given SparseZoo recipe
        if args.sparseml_recipe.startswith("zoo:"):
            args.initial_checkpoint = Zoo.download_recipe_base_framework_files(
                args.sparseml_recipe,
                extensions=[".pth.tar", ".pth"]
            )[0]
        else:
            raise ValueError(
                "Attempting to load weights from SparseZoo recipe, but not given a "
                "SparseZoo recipe stub. When initial-checkpoint is set to 'zoo', "
                "sparseml-recipe must start with 'zoo:' and be a SparseZoo model "
                f"stub. sparseml-recipe was set to {args.sparseml_recipe}"
            )
    elif args.initial_checkpoint.startswith("zoo:"):
        # Load weights from a SparseZoo model stub
        zoo_model = Zoo.load_model_from_stub(args.initial_checkpoint)
        args.initial_checkpoint = zoo_model.download_framework_files(
            extensions=[".pth"])[0]  # returns a list of paths; take the first
    ####################################################################################
    # End - SparseML optional load weights from SparseZoo
    ####################################################################################

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_tf=args.bn_tf,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel() for m in model.parameters()])))

    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

    # setup augmentation batch splits for contrastive loss or split bn
    num_aug_splits = 0
    if args.aug_splits > 0:
        assert args.aug_splits > 1, 'A split of 1 makes no sense'
        num_aug_splits = args.aug_splits

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # move model to GPU, enable channels last layout if set
    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        assert not args.split_bn
        if has_apex and use_amp != 'native':
            # Apex SyncBN preferred unless native amp is activated
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if args.local_rank == 0:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    if args.torchscript:
        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
        model = torch.jit.script(model)

    optimizer = create_optimizer(args, model)

    # setup automatic mixed-precision (AMP) loss scaling and op casting
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        if args.local_rank == 0:
            _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        if args.local_rank == 0:
            _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        if args.local_rank == 0:
            _logger.info('AMP not enabled. Training in float32.')

    # optionally resume from a checkpoint
    resume_epoch = None
    if args.resume:
        resume_epoch = resume_checkpoint(
            model, args.resume,
            optimizer=None if args.no_resume_opt else optimizer,
            loss_scaler=None if args.no_resume_opt else loss_scaler,
            log_info=args.local_rank == 0)

    # setup exponential moving average of model weights, SWA could be used here too
    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEmaV2(
            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
        if args.resume:
            load_checkpoint(model_ema.module, args.resume, use_ema=True)

    # setup distributed training
    if args.distributed:
        if has_apex and use_amp != 'native':
            # Apex DDP preferred unless native amp is activated
            if args.local_rank == 0:
                _logger.info("Using NVIDIA APEX DistributedDataParallel.")
            model = ApexDDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                _logger.info("Using native Torch DistributedDataParallel.")
            model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    # setup learning rate schedule and starting epoch
    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
        if args.prefetcher:
            assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
            collate_fn = FastCollateMixup(**mixup_args)
        else:
            mixup_fn = Mixup(**mixup_args)

    # wrap dataset in AugMix helper
    if num_aug_splits > 1:
        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)

    # create data loaders w/ augmentation pipeline
    train_interpolation = args.train_interpolation
    if args.no_aug or not train_interpolation:
        train_interpolation = data_config['interpolation']
    loader_train = create_loader(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        no_aug=args.no_aug,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        re_split=args.resplit,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        num_aug_splits=num_aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        pin_memory=args.pin_mem,
        use_multi_epochs_loader=args.use_multi_epochs_loader
    )

    loader_eval = create_loader(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=args.validation_batch_size_multiplier * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

    # setup loss function
    if args.jsd:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
    elif mixup_active:
        # smoothing is handled with mixup target transform
        train_loss_fn = SoftTargetCrossEntropy().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()

    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        saver = CheckpointSaver(
            model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)

    ####################################################################################
    # Start SparseML Integration
    ####################################################################################
    sparseml_loggers = (
        [PythonLogger(), TensorBoardLogger(log_path=output_dir)]
        if output_dir
        else None
    )
    manager = ScheduledModifierManager.from_yaml(args.sparseml_recipe)
    optimizer = ScheduledOptimizer(
        optimizer,
        model,
        manager,
        steps_per_epoch=len(loader_train),
        loggers=sparseml_loggers
    )
    # override lr scheduler if recipe makes any LR updates
    if any("LearningRate" in str(modifier) for modifier in manager.modifiers):
        _logger.info("Disabling timm LR scheduler, managing LR using SparseML recipe")
        lr_scheduler = None
    if manager.max_epochs:
        _logger.info(
            f"Overriding max_epochs to {manager.max_epochs} from SparseML recipe"
        )
        num_epochs = manager.max_epochs or num_epochs
    ####################################################################################
    # End SparseML Integration
    ####################################################################################

    if args.local_rank == 0:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                if args.local_rank == 0:
                    _logger.info("Distributing BatchNorm running means and vars")
                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

            eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)

            if model_ema is not None and not args.model_ema_force_cpu:
                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                ema_eval_metrics = validate(
                    model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                eval_metrics = ema_eval_metrics

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(
                epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)

        #################################################################################
        # Start SparseML ONNX Export
        #################################################################################
        if output_dir:
            _logger.info(
                f"training complete, exporting ONNX to {output_dir}/model.onnx"
            )
            exporter = ModuleExporter(model, output_dir)
            exporter.export_onnx(torch.randn((1, *data_config["input_size"])))
        #################################################################################
        # End SparseML ONNX Export
        #################################################################################

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
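The SparseML pieces above reduce to a small pattern that can be lifted into other timm-style training loops; a hedged sketch, assuming the usual SparseML PyTorch import paths, an existing model/optimizer/loader_train, and a recipe file at a placeholder path:

import torch
from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
from sparseml.pytorch.utils import ModuleExporter

manager = ScheduledModifierManager.from_yaml("recipe.yaml")  # local file or zoo: stub
optimizer = ScheduledOptimizer(
    optimizer, model, manager, steps_per_epoch=len(loader_train)
)
# ... run the usual train/eval loop, honoring manager.max_epochs if set ...
ModuleExporter(model, "./output").export_onnx(torch.randn((1, 3, 224, 224)))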
Example #11
    from sparsezoo import Zoo

    if opt.weights == "zoo":
        # Load checkpoint from base weights associated with given SparseZoo recipe
        if opt.sparseml_recipe.startswith("zoo:"):
            opt.weights = Zoo.download_recipe_base_framework_files(
                opt.sparseml_recipe, extensions=[".pt", ".pth"])[0]
        else:
            raise ValueError(
                "Attempting to load weights from SparseZoo recipe, but not given a "
                "SparseZoo recipe stub. When --weights is set to 'zoo', "
                "sparseml-recipe must start with 'zoo:' and be a SparseZoo model "
                f"stub. sparseml-recipe was set to {opt.sparseml_recipe}")
    elif opt.weights.startswith("zoo:"):
        # Load weights from a SparseZoo model stub
        zoo_model = Zoo.load_model_from_stub(opt.weights)
        opt.weights = zoo_model.download_framework_files(
            extensions=[".pt", ".pth"])[0]
    ####################################################################################
    # End - SparseML optional load weights from SparseZoo
    ####################################################################################

    # Resume
    if opt.resume:  # resume an interrupted run
        ckpt = opt.resume if isinstance(
            opt.resume,
            str) else get_latest_run()  # specified or most recent path
        assert os.path.isfile(
            ckpt), 'ERROR: --resume checkpoint does not exist'
        apriori = opt.global_rank, opt.local_rank
        with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
Example #12
def load_model(
    path: str,
    model: Module,
    strict: bool = False,
    ignore_error_tensors: List[str] = None,
    fix_data_parallel: bool = True,
):
    """
    Load the state dict into a model from a given file.

    :param path: the path to the pth file to load the state dict from.
        May also be a SparseZoo stub path preceded by 'zoo:' with the optional
        `?recipe_type=` argument. If given a recipe type, the base model weights
        for that recipe will be loaded.
    :param model: the model to load the state dict into
    :param strict: True to enforce that all tensors match between the model
        and the file; False otherwise
    :param ignore_error_tensors: names of tensors to ignore if they are not found
        in either the model or the file
    :param fix_data_parallel: fix the keys in the model state dict if they
        look like they came from a DataParallel type setup (start with module.).
        This removes the "module." prefix from all keys
    """
    if path.startswith("zoo:"):
        if "recipe_type=" in path:
            path = Zoo.download_recipe_base_framework_files(
                path, extensions=[".pth"])[0]
        else:
            path = Zoo.load_model_from_stub(path).download_framework_files(
                extensions=[".pth"])[0]
    model_dict = torch.load(path, map_location="cpu")
    current_dict = model.state_dict()

    if "state_dict" in model_dict:
        model_dict = model_dict["state_dict"]

    # check if any keys were saved through DataParallel type setup and convert those
    if fix_data_parallel:
        keys = list(model_dict.keys())
        module_key = "module."

        for key in keys:
            if key.startswith(module_key):
                new_key = key[len(module_key):]
                model_dict[new_key] = model_dict[key]
                del model_dict[key]

    if not ignore_error_tensors:
        ignore_error_tensors = []

    for ignore in ignore_error_tensors:
        if ignore not in model_dict and ignore not in current_dict:
            continue

        if (ignore in model_dict and ignore in current_dict
                and current_dict[ignore].shape != model_dict[ignore].shape):
            model_dict[ignore] = current_dict[ignore]
        elif ignore not in model_dict and ignore in current_dict:
            model_dict[ignore] = current_dict[ignore]
        elif ignore in model_dict and ignore not in current_dict:
            del model_dict[ignore]

    model.load_state_dict(model_dict, strict)
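A hedged usage sketch, assuming net is a torch.nn.Module whose architecture matches the checkpoint behind the (illustrative) stub:

load_model(
    "zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned-moderate",
    net,
    strict=True,  # require every tensor to match between file and model
)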