import os.path as osp
import time
from tempfile import TemporaryDirectory

import mmcv
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner import checkpoint as cp
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch.optim import Optimizer

# NOTE: ``PruneParams``, ``prune_top2_layers`` and ``str2list`` are
# project-local helpers whose imports are not shown in the original source;
# they are assumed to be available in this package.


def _save_checkpoint(model,
                     filename,
                     optimizer_b=None,
                     optimizer_g=None,
                     optimizer_d=None,
                     meta=None):
    """Save checkpoint to file.

    The checkpoint will have up to 5 fields: ``meta``, ``state_dict`` and,
    for each optimizer passed in, an ``optimizer_b``, ``optimizer_g`` or
    ``optimizer_d`` field. By default ``meta`` will contain version and
    time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer_b (:obj:`Optimizer` | dict, optional): Optimizer to be
            saved under the ``optimizer_b`` key.
        optimizer_g (:obj:`Optimizer` | dict, optional): Optimizer to be
            saved under the ``optimizer_g`` key.
        optimizer_d (:obj:`Optimizer` | dict, optional): Optimizer to be
            saved under the ``optimizer_d`` key.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    mmcv.mkdir_or_exist(osp.dirname(filename))
    if is_module_wrapper(model):
        model = model.module

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(model.state_dict())
    }
    # save each optimizer's state dict in the checkpoint; a dict of
    # optimizers is stored as a dict of state dicts
    for key, optimizer in (('optimizer_b', optimizer_b),
                           ('optimizer_g', optimizer_g),
                           ('optimizer_d', optimizer_d)):
        if isinstance(optimizer, Optimizer):
            checkpoint[key] = optimizer.state_dict()
        elif isinstance(optimizer, dict):
            checkpoint[key] = {
                name: optim.state_dict()
                for name, optim in optimizer.items()
            }
    # immediately flush buffer
    with open(filename, 'wb') as f:
        torch.save(checkpoint, f)
        f.flush()
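# A minimal usage sketch for ``_save_checkpoint``, added for illustration;
# it is an assumption, not part of the original source. The toy module and
# the base/gen/disc attribute names below are hypothetical, chosen because
# the b/g/d optimizer split suggests three separately-optimized
# sub-networks (e.g. a GAN-style setup).
def _demo_save_checkpoint():
    import torch.nn as nn

    class ToyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.base = nn.Linear(4, 8)
            self.gen = nn.Linear(8, 4)
            self.disc = nn.Linear(4, 1)

    model = ToyModel()
    # one optimizer per sub-network; each ends up under its own key
    _save_checkpoint(
        model,
        'work_dir/demo_iter_1.pth',
        optimizer_b=torch.optim.SGD(model.base.parameters(), lr=0.02),
        optimizer_g=torch.optim.Adam(model.gen.parameters(), lr=1e-4),
        optimizer_d=torch.optim.Adam(model.disc.parameters(), lr=1e-4),
        meta=dict(iter=1))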
def prune_mask_rcnn_only(args: PruneParams):
    """Just prune, without retraining.

    Args:
        args (PruneParams): Pruning parameters (config path, checkpoint
            path, backbone type, prune rates and result path).

    Returns:
        MaskRCNN: The pruned model on CUDA.
    """
    cfg = mmcv.Config.fromfile(args.config)
    model = build_detector(
        cfg.model,
        train_cfg=cfg.get('train_cfg'),
        test_cfg=cfg.get('test_cfg'))
    assert cfg.model['type'] == 'MaskRCNN', 'model type should be MaskRCNN!'

    # load checkpoint
    checkpoint = cp.load_checkpoint(model=model, filename=args.checkpoint)
    num_before = sum(p.nelement() for p in model.backbone.parameters())
    print('Before pruning: Backbone Params = %.2fM' % (num_before / 1E6))

    # PRUNE FILTERS
    # func = {"ResNet50": prune_resnet50, "ResNet101": prune_resnet101}
    assert args.backbone in ['ResNet50', 'ResNet101'], 'Wrong backbone type!'
    # indices of layers that are skipped during pruning, per backbone
    skip = {
        'ResNet34': [2, 8, 14, 16, 26, 28, 30, 32],
        'ResNet50': [2, 11, 20, 23, 89, 92, 95, 98],
        'ResNet101': [2, 11, 20, 23, 89, 92, 95, 98]
    }
    pf_cfg, new_backbone = prune_top2_layers(
        arch=args.backbone,
        net=model.backbone,
        skip_block=skip[args.backbone],
        prs=str2list(args.prs),
        cuda=True)
    model.backbone = new_backbone
    num_after = sum(p.nelement() for p in model.backbone.parameters())
    print('After pruning: Backbone Params = %.2fM' % (num_after / 1E6))
    print('Prune rate: %.2f%%' % ((num_before - num_after) / num_before * 100))

    # replace checkpoint['state_dict'] with the pruned weights
    checkpoint['state_dict'] = cp.weights_to_cpu(cp.get_state_dict(model))
    mmcv.mkdir_or_exist(osp.dirname(args.result_path))
    # save and immediately flush buffer
    torch.save(checkpoint, args.result_path)
    # use splitext so dots elsewhere in the path do not truncate it
    with open(osp.splitext(args.result_path)[0] + '_cfg.txt', 'w') as f:
        f.write(str(pf_cfg))

    return model
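# A hedged usage sketch for ``prune_mask_rcnn_only``, added for
# illustration and not part of the original source. The ``PruneParams``
# field names (config, checkpoint, backbone, prs, result_path) are
# inferred from how ``args`` is used above; the keyword-constructor call
# and all file paths are assumptions.
def _demo_prune_mask_rcnn_only():
    args = PruneParams(
        config='configs/mask_rcnn_r50_fpn_1x_coco.py',      # mmdet config
        checkpoint='checkpoints/mask_rcnn_r50_fpn_1x.pth',  # trained weights
        backbone='ResNet50',
        prs='0.5,0.5',  # per-layer prune rates, parsed by str2list
        result_path='work_dir/mask_rcnn_r50_pruned.pth')
    return prune_mask_rcnn_only(args)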
def save_checkpoint(model,
                    filename,
                    optimizer=None,
                    loss_scaler=None,
                    save_apex_amp=False,
                    meta=None):
    """Save checkpoint to file.

    The checkpoint will have 3 or more fields: ``meta``, ``state_dict``
    and ``optimizer``. By default ``meta`` will contain version and time
    info. In mixed-precision training, ``loss_scaler`` or
    ``amp.state_dict`` will be saved in checkpoint.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        loss_scaler (Object, optional): Loss scaler used for FP16 training.
        save_apex_amp (bool, optional): Whether to save ``apex.amp``
            ``state_dict``.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()
    # save loss scaler for mixed-precision (FP16) training
    if loss_scaler is not None:
        checkpoint['loss_scaler'] = loss_scaler.state_dict()
    # save state_dict from apex.amp
    if save_apex_amp:
        from apex import amp
        checkpoint['amp'] = amp.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to load checkpoint from modelcloud.')
        model_path = filename[7:]
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
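# A round-trip sketch for ``save_checkpoint``, added for illustration and
# not part of the original source: it saves a toy model together with its
# optimizer, then reloads the file with ``torch.load`` and restores both
# state dicts.
def _demo_save_and_reload():
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    save_checkpoint(model, 'work_dir/demo.pth', optimizer=optimizer)

    checkpoint = torch.load('work_dir/demo.pth', map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print(checkpoint['meta'])  # contains mmcv_version and time by default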