def prune_mask_rcnn_only(args: PruneParams):
    """Prune a MaskRCNN backbone's filters without retraining.

    Loads the config and checkpoint, prunes the top-2 layers of each
    backbone block, swaps the pruned backbone into the model, and saves
    the updated checkpoint plus a text dump of the pruning config.

    Args:
        args (PruneParams): Holds ``config``, ``checkpoint``, ``backbone``
            ('ResNet50' or 'ResNet101'), ``prs`` (prune rates, parsed by
            ``str2list``) and ``result_path`` for the output checkpoint.

    Returns:
        (MaskRCNN): pruned model in cuda.
    """
    cfg = mmcv.Config.fromfile(args.config)
    model = build_detector(cfg.model,
                           train_cfg=cfg.get('train_cfg'),
                           test_cfg=cfg.get('test_cfg'))
    assert cfg.model['type'] == 'MaskRCNN', 'model type should be MaskRCNN!'
    # load checkpoint
    checkpoint = cp.load_checkpoint(model=model, filename=args.checkpoint)
    num_before = sum(p.nelement() for p in model.backbone.parameters())
    print('Before pruning, Backbone Params = %.2fM' % (num_before / 1E6))
    # PRUNE FILTERS
    assert args.backbone in ['ResNet50', 'ResNet101'], "Wrong backbone type!"
    # Block indices that must not be pruned (e.g. downsample/identity paths).
    # NOTE(review): the 'ResNet34' entry is unreachable given the assert above.
    skip = {
        'ResNet34': [2, 8, 14, 16, 26, 28, 30, 32],
        'ResNet50': [2, 11, 20, 23, 89, 92, 95, 98],
        'ResNet101': [2, 11, 20, 23, 89, 92, 95, 98]
    }
    pf_cfg, new_backbone = prune_top2_layers(arch=args.backbone,
                                             net=model.backbone,
                                             skip_block=skip[args.backbone],
                                             prs=str2list(args.prs),
                                             cuda=True)
    model.backbone = new_backbone
    num_after = sum(p.nelement() for p in model.backbone.parameters())
    print('After pruning: Backbone Params = %.2fM' % (num_after / 1E6))
    print("Prune rate: %.2f%%" % ((num_before - num_after) / num_before * 100))
    # replace checkpoint['state_dict'] with the pruned weights (on CPU)
    checkpoint['state_dict'] = cp.weights_to_cpu(cp.get_state_dict(model))
    mmcv.mkdir_or_exist(osp.dirname(args.result_path))
    # save and immediately flush buffer
    torch.save(checkpoint, args.result_path)
    # Use splitext, not split('.'), so dots in directory names (e.g. './out')
    # don't truncate the path.
    cfg_txt_path = osp.splitext(args.result_path)[0] + '_cfg.txt'
    with open(cfg_txt_path, 'w') as f:
        f.write(str(pf_cfg))
    # Return the pruned model as promised by the docstring (the original
    # implicitly returned None).
    return model
def test_get_state_dict():
    """Verify get_state_dict unwraps module wrappers at every nesting level.

    Three phases: a plain model, a DDP-wrapped model, and a DDP-wrapped
    model whose direct children are additionally wrapped in DataParallel.
    In all cases the returned state dict must expose unwrapped key names
    and tensors identical to the underlying parameters/buffers.
    """
    is_parrots = torch.__version__ == 'parrots'
    state_dict_keys = {
        'block.conv.weight', 'block.conv.bias',
        'block.norm.weight', 'block.norm.bias',
        'block.norm.running_mean', 'block.norm.running_var',
        'conv.weight', 'conv.bias'
    }
    if not is_parrots:
        # parrots' BatchNorm does not track batch counts
        state_dict_keys.add('block.norm.num_batches_tracked')

    def _check(state_dict, block_mod, conv_mod):
        # Compare every entry of the state dict against the given
        # (possibly wrapper-unwrapped) block and conv modules.
        assert isinstance(state_dict, OrderedDict)
        assert set(state_dict.keys()) == state_dict_keys
        assert_tensor_equal(state_dict['block.conv.weight'],
                            block_mod.conv.weight)
        assert_tensor_equal(state_dict['block.conv.bias'],
                            block_mod.conv.bias)
        assert_tensor_equal(state_dict['block.norm.weight'],
                            block_mod.norm.weight)
        assert_tensor_equal(state_dict['block.norm.bias'],
                            block_mod.norm.bias)
        assert_tensor_equal(state_dict['block.norm.running_mean'],
                            block_mod.norm.running_mean)
        assert_tensor_equal(state_dict['block.norm.running_var'],
                            block_mod.norm.running_var)
        if not is_parrots:
            assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
                                block_mod.norm.num_batches_tracked)
        assert_tensor_equal(state_dict['conv.weight'], conv_mod.weight)
        assert_tensor_equal(state_dict['conv.bias'], conv_mod.bias)

    # plain model
    model = Model()
    _check(get_state_dict(model), model.block, model.conv)

    # model behind a DDP-style wrapper
    wrapped_model = DDPWrapper(model)
    _check(get_state_dict(wrapped_model),
           wrapped_model.module.block,
           wrapped_model.module.conv)

    # additionally wrap each direct child of the inner module
    for name, child in wrapped_model.module._modules.items():
        wrapped_model.module._modules[name] = DataParallel(child)
    _check(get_state_dict(wrapped_model),
           wrapped_model.module.block.module,
           wrapped_model.module.conv.module)
def save_checkpoint(model,
                    filename,
                    optimizer=None,
                    loss_scaler=None,
                    save_apex_amp=False,
                    meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 or more fields: ``meta``, ``state_dict`` and
    ``optimizer``. By default ``meta`` will contain version and time info. In
    mixed-precision training, ``loss_scaler`` or ``amp.state_dict`` will be
    saved in checkpoint.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        loss_scaler (Object, optional): Loss scaler used for FP16 training.
        save_apex_amp (bool, optional): Whether to save apex.amp
            ``state_dict``.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    # unwrap DP/DDP-style wrappers before reading attributes / weights
    if is_module_wrapper(model):
        model = model.module
    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint (single or named dict)
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict()
            for name, optim in optimizer.items()
        }
    # save loss scaler for mixed-precision (FP16) training
    if loss_scaler is not None:
        checkpoint['loss_scaler'] = loss_scaler.state_dict()
    # save state_dict from apex.amp
    if save_apex_amp:
        from apex import amp
        checkpoint['amp'] = amp.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            pavi_model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            pavi_model = root.create_training_model(model_dir)
        # serialize into a temp file, then upload it to the model cloud
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            pavi_model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()