def run_opt(self, model, max_num_params, max_num_ops,
            strategy="MAX_COST", **kwargs):
    """ Run the actual optimization.

    The model is already loaded at this point, and every module that
    should be grouped has been replaced by MaskConv2d. We update the
    G value of each module in place.
    """
    logging.info("Finding the optimal group configuration ...")
    logging.info("Max # param: {:.2f} M".format(max_num_params))
    logging.info("Max # ops: {:.2f} M".format(max_num_ops))

    c_dict = OrderedDict()  # candidate dictionary
    state = OrderedDict()  # G index for each group

    logging.debug("==> Constructing the candidate dictionary ...")
    for name, mod in model.named_modules():
        if isinstance(mod, MaskConv2d):
            Gs, costs = list(self.find_group_candidates(mod, **kwargs))
            c_dict[name] = (Gs, costs, self.get_mod_norm(mod))
            state[name] = len(Gs) - 1  # start from the last (largest) group

    # Initial state
    self.assign_state_to_model(model, state, c_dict)
    logging.debug(
        "Initial state: {:.2f} M params {:.2f} M ops cost: {:.2f}%".format(
            model_utils.get_model_num_params(model),
            utils.get_model_num_ops(model, self.args.dataset),
            (1 - self.get_cost(state, c_dict)) * 100,
        ))

    # Update the group setting depending on the strategy.
    if strategy == "ILP":
        # HACK - needs refactoring
        # state = self.ilp_solve(c_dict)
        raise NotImplementedError(
            "ILP as an optimisation strategy is not implemented")
    elif strategy == "MAX_COST":
        state, cost = self.max_cost_solver(model, state, c_dict,
                                           max_num_params, max_num_ops)
    else:
        raise ValueError("Cannot recognise strategy: {}".format(strategy))

    # Finalise
    self.assign_state_to_model(model, state, c_dict)
    num_params = model_utils.get_model_num_params(model)
    num_ops = utils.get_model_num_ops(model, self.args.dataset)
    logging.debug(
        "Final state - params: {:.2f} M ops: {:6.2f} M cost: {:.2f}%".format(
            num_params, num_ops, (1 - cost) * 100))
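# `run_opt` and `max_cost_solver` lean on a `get_cost` helper that is not
# shown here. The following is a minimal sketch of one plausible
# implementation, assuming the cost of a state is the norm-weighted mean
# of the per-candidate costs stored in `c_dict`; the real helper may
# differ.
def get_cost(self, state, c_dict):
    total_cost, total_norm = 0.0, 0.0
    for name, gi in state.items():
        _, costs, norm = c_dict[name]  # (Gs, costs, module weight norm)
        total_cost += costs[gi] * norm  # cost of the chosen candidate
        total_norm += norm
    return total_cost / total_norm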
def max_cost_solver(self, model, state, c_dict, max_num_params,
                    max_num_ops):
    num_params = model_utils.get_model_num_params(model)
    num_ops = utils.get_model_num_ops(model, self.args.dataset)
    cost = self.get_cost(state, c_dict)

    if num_params > max_num_params or num_ops > max_num_ops:
        raise ValueError(
            'The constraints are too restrictive: even the smallest '
            'configuration takes params {:.2f} M ops {:.2f} M.'.format(
                num_params, num_ops))

    step = 0
    while num_params <= max_num_params and num_ops <= max_num_ops:
        # Collect the cost of every single-step relaxation we can try.
        costs = []
        for name, gi in state.items():
            if gi == 0:  # cannot reduce this group any further
                continue
            state_ = state.copy()
            state_[name] = gi - 1  # this will definitely increase the cost
            cost_ = self.get_cost(state_, c_dict)
            costs.append((name, cost_))

        if not costs:
            logging.debug('No more choices for MAX_COST, exiting ...')
            break

        # Greedily take the relaxation with the maximal cost.
        max_cost = max(costs, key=lambda k: k[1])
        state[max_cost[0]] -= 1
        cost = max_cost[1]

        # Update the model and re-measure its size.
        self.assign_state_to_model(model, state, c_dict)
        if use_cuda:  # post-fix model placement
            model.cuda()
        num_params = model_utils.get_model_num_params(model)
        num_ops = utils.get_model_num_ops(model, self.args.dataset)

        step += 1
        logging.debug(
            '[{:4d}] Current state - params: {:.2f} M ops: {:6.2f} M '
            'cost: {:.2f}%'.format(step, num_params, num_ops,
                                   (1 - cost) * 100))

    return state, cost
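# Hypothetical usage of the optimiser above (the `opt` instance and the
# budget values are made up for illustration): cap the grouped model at
# 2.5 M parameters and 300 M ops using the greedy MAX_COST strategy.
opt.run_opt(model, max_num_params=2.5, max_num_ops=300.0,
            strategy="MAX_COST")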
def main():
    parser = create_cli_parser()
    args = parser.parse_args()

    set_random_seed(42)

    cfg = GumiConfig(**vars(args))
    cfg = patch_dataset(cfg)

    logging.info("==> Initializing ModelPruner ...")
    model_pruner = ModelPruner(cfg)
    model = model_pruner.load_model(update_state_dict_fn=update_state_dict)
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Replacing Conv2d in model by MaskConv2d ...")
    utils.apply_mask(model)
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Pruning model ...")
    get_num_groups = group_utils.create_get_num_groups_fn(
        G=cfg.num_groups, MCPG=cfg.mcpg, group_cfg=cfg.group_cfg)
    model_pruner.prune(
        model,
        get_num_groups_fn=get_num_groups,
        perm=cfg.perm,
        no_weight=cfg.no_weight,
        num_iters=cfg.num_sort_iters,
        keep_mask=False,
    )
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Fine-tuning model ...")
    model_pruner.fine_tune(model, record_top5=False)
    model_pruner.validate(model, record_top5=False)

    logging.info("==> Exporting the model ...")
    model = GroupExporter.export(model)
    logging.debug("Total params: {:.2f}M FLOPS: {:.2f}M".format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, args.dataset),
    ))

    logging.info("==> Saving exported model ...")
    torch.save(model, os.path.join(cfg.checkpoint, "pruned.pth.tar"))
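# A minimal sketch of the `update_state_dict` hook passed to `load_model`
# above, assuming its job is the usual checkpoint fix-up: strip the
# 'module.' prefix that nn.DataParallel checkpoints carry so the weights
# load into a plain single-GPU model. The actual hook may do more.
from collections import OrderedDict

def update_state_dict(state_dict):
    return OrderedDict(
        (k[len("module."):] if k.startswith("module.") else k, v)
        for k, v in state_dict.items())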
def main():
    use_cuda = not args.cpu

    print("==> Loading model {}".format(args.arch))
    model = utils.load_model(args.arch,
                             "imagenet",
                             use_cuda=use_cuda,
                             pretrained=True)

    print("==> Loading group config {}".format(args.group_cfg))
    with open(args.group_cfg, "r") as f:
        group_cfg = json.load(f)

    print("==> Updating model ...")
    model = update_model(model, group_cfg, use_cuda=use_cuda)
    print(model)
    print("==> Model size: {:.2f} M ops: {:.2f} M".format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, "imagenet"),
    ))

    torch.save(model,
               os.path.join(os.path.dirname(args.resume), "peng.pth.tar"))
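# A sketch of the `update_model` helper used above, under the assumption
# that `group_cfg` maps module names to group counts: each configured
# Conv2d is swapped for an equivalent grouped convolution. Illustrative
# only; the real helper presumably also copies or regroups the weights.
import torch.nn as nn

def update_model(model, group_cfg, use_cuda=True):
    for name, g in group_cfg.items():
        # Walk to the parent module so we can replace the leaf in place.
        parent = model
        *path, leaf = name.split('.')
        for p in path:
            parent = getattr(parent, p)
        mod = getattr(parent, leaf)
        setattr(parent, leaf, nn.Conv2d(
            mod.in_channels, mod.out_channels, mod.kernel_size,
            stride=mod.stride, padding=mod.padding, groups=g,
            bias=mod.bias is not None))
    return model.cuda() if use_cuda else model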
def profile(self, iters, use_cuda=False):
    """ Run the profiling sweep over a grid of group counts. """
    groups = [1, 2, 4, 8, 16, 32, 64]
    sparse = (True, False)

    data = []
    columns = ['G', 'is_sparse', 'num_params', 'num_ops', 'time']

    for g in groups:
        for is_sparse in sparse:
            # Only the G=1 baseline is profiled without sparsity.
            if g != 1 and not is_sparse:
                continue

            model = MobileNet(groups=g, sparse=is_sparse,
                              num_channels=1000)
            update_sparse_weight(model, block=True)

            num_params = model_utils.get_model_num_params(model)
            num_ops = utils.get_model_num_ops(model, self.args.dataset)

            x = torch.rand((1, 3, 224, 224))
            if use_cuda:
                x = x.cuda()
                model.cuda()

            logging.info(
                '==> Profiling G={} sparse={} for {} iters ...'.format(
                    g, is_sparse, iters))
            elapsed = self.profile_case(x, model, iters=iters)
            logging.info(
                '\t# params. {:.2f}M # ops {:.2f}M '
                'Time elapsed: {:.2f} ms'.format(num_params, num_ops,
                                                 elapsed))

            data.append([g, is_sparse, num_params, num_ops, elapsed])

    return pd.DataFrame(data, columns=columns)
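# A sketch of the `profile_case` timing helper called above (assumed
# implementation): average forward latency in milliseconds over `iters`
# runs, with warm-up passes and CUDA synchronisation when the input
# lives on the GPU.
import time

def profile_case(self, x, model, iters=100, warmup=10):
    model.eval()
    with torch.no_grad():
        for _ in range(warmup):
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize()
    return (time.time() - start) * 1e3 / iters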
def main(): """ Main """ logging.info("==> Initializing ModelPruner ...") model_pruner = ModelPruner(args) # load model logging.info("==> Loading model ...") update_model_fn = create_update_model_fn( args.arch, args.dataset, args.pretrained, args.resume, apply_mask=args.apply_mask, condensenet=args.condensenet, ) model = model_pruner.load_model( update_model_fn=update_model_fn, update_state_dict_fn=create_update_state_dict_fn( no_mask=not args.resume, condensenet=args.condensenet), fine_tune=args.fine_tune, ) # evaluate the performance of the model in the beginning if not args.skip_validation: logging.info("==> Validating the loaded model ...") loss1, acc1 = model_pruner.validate(model) ################################################# # Pruning # # # ################################################# if not args.apply_mask: # NOTE: we have not applied mask yet # # major pruning function logging.info("==> Replacing Conv2d in model by MaskConv2d ...") # TODO - duplicated with update_model_fn? # not quite, if not resume the model won't be updated utils.apply_mask(model) if not args.skip_validation: logging.info("==> Validating the masked model ...") loss2, acc2 = model_pruner.validate(model) assert torch.allclose(acc1, acc2) # run pruning (update the content of mask) logging.info("==> Pruning model ...") if not args.skip_prune: get_num_groups = create_get_num_groups_fn(G=args.num_groups, MCPG=args.mcpg, group_cfg=args.group_cfg) logging.debug("Pruning configuration:") logging.debug("PERM: {}".format(args.perm)) logging.debug("NS: {}".format(args.num_sort_iters)) logging.debug("No weight: {}".format(args.no_weight)) logging.debug("Keep mask: {}".format(args.keep_mask)) logging.debug("") model_pruner.prune( model, get_num_groups_fn=get_num_groups, perm=args.perm, no_weight=args.no_weight, num_iters=args.num_sort_iters, keep_mask=args.keep_mask, ) if not args.skip_validation: logging.info("==> Validating the pruned model ...") loss3, acc3 = model_pruner.validate(model) else: logging.info("Pruning has been skipped, you have the original model.") ################################################# # Fine-tuning # # # ################################################# logging.info("==> Fine-tuning the pruned model ...") if args.train_from_scratch: logging.info("==> Training the pruned topology from scratch ...") # reset weight parameters # TODO: refactorize for name, mod in model.named_modules(): if hasattr( mod, "weight") and len(mod.weight.shape) >= 2: # re-initialize torch.nn.init.kaiming_normal_(mod.weight, nonlinearity="relu") # if hasattr(mod, 'G'): # mod.weight.data.mul_(mod.G) if hasattr(mod, "bias") and mod.bias is not None: mod.bias.data.fill_(0.0) if not args.skip_fine_tune: model_pruner.fine_tune(model) if not args.skip_validation: logging.info("==> Validating the fine-tuned pruned model ...") loss4, acc4 = model_pruner.validate(model) logging.info( "==> Final validation accuracy of the pruned model: {:.2f}%". 
                .format(acc4))
    else:
        logging.info("Fine-tuning has been skipped.")

    #################################################
    # Export                                        #
    #################################################
    logging.info("==> Exporting the model ...")
    model = GroupExporter.export(model)
    if use_cuda:
        model.cuda()

    logging.debug("Total params: {:.2f}M FLOPS: {:.2f}M".format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, args.dataset),
    ))

    if not args.skip_validation:
        logging.info("==> Validating the exported pruned model ...")
        loss5, acc5 = model_pruner.validate(model)
        logging.info(
            "==> Final validation accuracy of the exported model: "
            "{:.2f}%".format(acc5))
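# For context, a minimal sketch of what a MaskConv2d forward pass might
# look like (an assumption for illustration, not the repo's actual
# implementation): a regular convolution whose weight is gated
# elementwise by a mask buffer, which is what `utils.apply_mask`
# installs and what pruning updates.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Start fully dense; pruning later zeroes entries of the mask.
        self.register_buffer('mask', torch.ones_like(self.weight))

    def forward(self, x):
        return F.conv2d(x, self.weight * self.mask, self.bias,
                        self.stride, self.padding, self.dilation,
                        self.groups)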
def main(): """ Load the model and prune it """ logging.info('==> Initializing GroupModelExporter ...') exporter = GroupModelExporter(args) # load model logging.info('==> Loading model ...') model = exporter.load_model( use_cuda=True, data_parallel=False, update_model_fn=create_update_model_fn(), update_state_dict_fn=create_update_state_dict_fn()) # prune it to setup each MaskConv2d get_num_groups_fn = utils.create_get_num_groups_fn( G=args.num_groups, MCPG=args.mcpg, group_cfg=args.group_cfg, use_cuda=True, data_parallel=False) exporter.prune(model, get_num_groups_fn=get_num_groups_fn, keep_mask=True, use_cuda=True) for name, mod in model.named_modules(): if isinstance(mod, MaskConv2d): print(name, mod.G) # export logging.info('==> Exporting the model ...') model = GroupExporter.export(model, use_cuda=True, mm=args.mm, sparse=args.sparse, std=args.std, min_sparse_channels=args.min_sparse_channels) # evaluate if args.val: exporter.validate(model) # move back to CPU model.cpu() # print(model) logging.debug('Total params: {:.2f}M FLOPS: {:.2f}M'.format( model_utils.get_model_num_params(model), utils.get_model_num_ops(model, args.dataset))) logging.debug('==> Densify model ...') utils.apply_dense(model) # save model if not args.onnx: suffix = '' if args.mm: suffix += '_mm' if args.sparse: suffix += '_sparse' if args.std: suffix += '_std' if args.min_sparse_channels > 0: suffix += '_min{}'.format(args.min_sparse_channels) fp = os.path.join(os.path.dirname(args.resume), 'gconv{}.pth.tar'.format(suffix)) logging.info('==> Saving the model to {} ...'.format(fp)) torch.save(model, fp) else: fp = os.path.join(os.path.dirname(args.resume), 'gconv.onnx') logging.info('==> Exporting model to ONNX {} ...'.format(fp)) if args.dataset in utils.IMAGENET_DATASETS: x = torch.randn(1, 3, 224, 224, requires_grad=True) else: raise RuntimeError('Do not support {}'.format(args.dataset)) onnx = torch.onnx._export(model, x, fp, export_params=True)
def main():
    use_cuda = not args.cpu

    if args.pretrained:
        logging.info('==> Loading pre-trained model ...')
        model = utils.load_model(args.arch,
                                 args.dataset,
                                 pretrained=args.pretrained,
                                 use_cuda=use_cuda)
    else:
        logging.info('==> Loading GConv model directly from Pickle ...')
        model = torch.load(args.resume)
        if not use_cuda:
            model.cpu()
    model.eval()

    logging.info('==> Sparsify model ...')
    utils.apply_sparse(model)
    print(model)
    logging.debug('Total params: {:.2f}M FLOPS: {:.2f}M'.format(
        model_utils.get_model_num_params(model),
        utils.get_model_num_ops(model, args.dataset)))

    if args.dataset in utils.IMAGENET_DATASETS:
        x = torch.rand((args.test_batch, 3, 224, 224))
    else:
        x = torch.rand((args.test_batch, 3, 32, 32))

    # set up the input and model placement
    if use_cuda:
        x = x.cuda()
        model.cuda()
    else:
        x = x.cpu()

    logging.info('==> Dry running ...')
    dry_run_iters = 10 if not use_cuda else 100
    for _ in range(dry_run_iters):
        y = model(x)

    logging.info('==> Print profiling info ...')
    with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
        y = model(x)
    print(prof)

    # start timing
    logging.info('==> Start timing ...')
    if use_cuda:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        for _ in range(args.iters):
            y = model(x)
        end.record()

        # wait for all kernels to finish before reading the timer
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end)  # milliseconds
    else:
        start = time.time()
        for _ in range(args.iters):
            y = model(x)
        end = time.time()
        elapsed = (end - start) * 1e3  # milliseconds

    print('Elapsed time: {:10.2f} sec (total) {:6.2f} ms (per run) '
          '{:6.2f} FPS.'.format(elapsed * 1e-3, elapsed / args.iters,
                                args.iters * args.test_batch / elapsed * 1e3))
def get_model_size(self, model):
    """ Return the model's number of parameters and ops. """
    num_params = model_utils.get_model_num_params(model)
    num_ops = utils.get_model_num_ops(model, self.args.dataset)
    return num_params, num_ops
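# Example usage of the helper above (the `pruner` instance is assumed):
# both figures are in millions, matching the logging format used
# throughout.
num_params, num_ops = pruner.get_model_size(model)
logging.info('Model size: {:.2f} M params, {:.2f} M ops'.format(
    num_params, num_ops))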