Example #1
    def __init__(self, model, dummy_input, apply_scope_name_workarounds=True):
        self._src_model = model
        self._named_modules = OrderedDict(model.named_modules())
        self._adj_map = None
        self._layers_topological_order = None
        self._top_level_ops = set()
        model_clone = utility.make_non_parallel_copy(model)

        # Switch all instances of torch.nn.ModuleList in the model to our DistillerModuleList
        # See documentation of _DistillerModuleList class for details on why this is done
        model_clone, converted_module_names_map = _to_distiller_modulelist(model_clone)

        with torch.onnx.set_training(model_clone, False):
            
            device = utility.model_device(model_clone)
            dummy_input = utility.convert_tensors_recursively_to(dummy_input, device=device)
            self.dummy_input = dummy_input
            trace, _ = jit.get_trace_graph(model_clone, dummy_input, _force_outplace=True)

            # As of PyTorch 1.3.0, ONNX trace optimization has an issue that results in incorrect scope names
            # of nodes in the trace graph.
            # In some cases this makes it impossible to derive the connectivity of the model using the original
            # module names, so we try to detect these cases and apply workarounds.

            # The issue is:
            #   Dropout ops are removed by ONNX trace optimization. However, the op BEFORE the original dropout op
            #   gets the scope name of the dropout op.
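            # For illustration only (hypothetical scope names): if the original trace contained
            #   %1 : Float(...) scope: Net/Linear[fc1]     # aten::addmm
            #   %2 : Float(...) scope: Net/Dropout[drop]   # aten::dropout
            # then after optimization the addmm node may carry the Dropout scope name. So, before
            # optimizing, we record a mapping from each dropout scope name to the scope name of the
            # op preceding it, and restore the original names further below.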
            pre_dropout_nodes_scope_names = OrderedDict()

            prev_non_dropout_op = None
            for node in trace.graph().nodes():
                kind = node.kind()
                if 'aten' not in kind:
                    continue
                if kind == 'aten::dropout':
                    if prev_non_dropout_op:
                        pre_dropout_nodes_scope_names[node.scopeName()] = prev_non_dropout_op.scopeName()
                else:
                    prev_non_dropout_op = node

            # Let ONNX do the heavy lifting: fusing the convolution nodes; fusing the nodes
            # composing a GEMM operation; etc.
            torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)

            graph = trace.graph()
            self.ops = OrderedDict()
            self.module_ops_map = defaultdict(list)
            self.params = OrderedDict()
            self.edges = []
            self.temp = OrderedDict()

            in_out = list(graph.inputs()) + list(graph.outputs())
            for param in in_out:
                self.__add_param(param)

            for node in graph.nodes():
                new_op = self.__create_op(node)

                if apply_scope_name_workarounds:
                    # Here we apply the workaround to the issue of dropout op scope name overriding previous op's
                    # scope name
                    if new_op['name'] in pre_dropout_nodes_scope_names:
                        new_op['orig-name'] = pre_dropout_nodes_scope_names[new_op['name']]
                        new_op['name'] = new_op['orig-name']

                # Convert the graph node's scope name to a PyTorch module name
                module_name = onnx_name_2_pytorch_name(new_op['orig-name'])

                # Get name from before conversion to DistillerModuleList
                module_name = converted_module_names_map[module_name]

                if len(module_name) == 0:
                    # Special case where the module name is an empty string - this happens
                    # when the op is called from the "top-level" of the model
                    new_op['name'] = 'top_level_op'
                else:
                    new_op['name'] = module_name

                # Save the calling module name in the op dict. Denormalize it so it can
                # be directly matched with the actual model.
                module_name = utility.denormalize_module_name(self._src_model, module_name)
                new_op['module-name'] = module_name

                # The node's scope name in the graph corresponds to the module from which the op was called.
                # This means that when ops are invoked from the same module via functional calls or direct
                # operations on tensors, these ops will have the SAME MODULE NAME associated with them.
                # For example:
                #   t = t1 + t2
                #   t = F.relu(t)
                # In this case the add operation and the ReLU operation will have the same name, which is
                # derived from the module they're contained in.
                #
                # Another case where different ops will have the same module name is when a module is reused:
                #   out = self.conv1(x)
                #   out = self.relu(out)    <=== First use of self.relu
                #   out = self.conv2(out)
                #   out = self.relu(out)    <=== Second use of self.relu
                # In this case the graph will have 2 distinct ReLU nodes, with the same scope name.
                #
                # Operators with the same name create very confusing graphs (in ResNet, for example),
                # so we "unroll" them.
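                # For example, if a module named 'relu' emits two ops of type 'Relu', the second
                # occurrence is renamed to 'relu_Relu_1' below (the first occurrence keeps the plain name).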
                same_module_cnt = len(self.module_ops_map[module_name])
                if same_module_cnt:
                    # TODO: Was this meant to be applied only to 'top_level_ops'? Also, it's not
                    #       applied to the first module that had the same name
                    new_op['name'] += "_%s_%d" % (new_op['type'], same_module_cnt)
                self.module_ops_map[module_name].append(new_op['name'])

                # Finally we register the new op in the ops collection
                #print("new sgraph node - Scope name: {} ; Type: {} ; Display name {}".format(
                #    new_op['orig-name'], new_op['type'], new_op['name']))
                self.ops[new_op['name']] = new_op

                for input_ in node.inputs():
                    self.__add_input(new_op, input_)
                    self.edges.append(SummaryGraph.Edge(input_.debugName(), new_op['name']))

                for output in node.outputs():
                    self.__add_output(new_op, output)
                    self.edges.append(SummaryGraph.Edge(new_op['name'], output.debugName()))

                new_op['attrs'] = OrderedDict([(attr_name, node[attr_name]) for attr_name in node.attributeNames()])

        self.__merge_pad_avgpool()
        self.add_macs_attr()
        self.add_footprint_attr()
        self.add_arithmetic_intensity_attr()
        del model_clone
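
A minimal usage sketch for the constructor above (the model, the input shape, and the way SummaryGraph is imported are placeholder assumptions, not part of the original example):

    import torch
    import torchvision.models as models

    # Build a model and a dummy input matching its expected input shape,
    # then construct the summary graph by tracing the model with it.
    model = models.resnet18()
    dummy_input = torch.randn(1, 3, 224, 224)
    sgraph = SummaryGraph(model, dummy_input)

    # The constructor populates the ops, params and edges collections.
    print(len(sgraph.ops), len(sgraph.edges))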
Example #2
        #Note: load_lean_checkpoint is designed for testing only.
        #net = ckpt.load_lean_checkpoint(net, "/home/bwtseng/Downloads/vww_mobilenetv1_distiller/model_save/resnet50_pruned_85_best.pth.tar",
        #                   model_device=device)
        top1, top5, loss = test(net, criterion, device)
        #import torchvision.models as models
        #alexnet = alex.alexnet(pretrained=args.pre_trained)
        #net = torch.nn.DataParallel(net).cuda()
        #alexnet.to(device)
        #top1, top5, loss = test(alexnet, criterion, device)

        t, total = summary.weights_sparsity_tbl_summary(
            net, return_total_sparsity=True)
        print('Total sparsity: {:0.2f}\n'.format(total))
        if args.qe_calibration and not (args.test and args.quantize_eval):
            test_fn = partial(test, criterion=criterion, device=args.device)
            cmodel = utl.make_non_parallel_copy(net)
            collector.collect_quant_stats(cmodel,
                                          test_fn,
                                          classes=None,
                                          inplace_runtime_check=True,
                                          disable_inplace_attrs=True,
                                          save_dir="mobilenet_status")
        if args.post_qe_test:
            quantize_and_test_model(dataloaders['val'], net, criterion, args)

        #top1, loss = test(net, criterion, device)

    #net = train_model(net, criterion, optimizer, exp_lr_scheduler,
    #                       device, args.model_path, num_epochs=80)
    #net.to(device)
    #simple_train_for_distiller(net, criterion, optimizer, exp_lr_scheduler, compress_scheduler, device)
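
A minimal sketch of the command-line interface the snippet above reads from args (flag names are inferred from the args.* attributes used and are assumptions; defaults are placeholders):

    import argparse

    parser = argparse.ArgumentParser()
    # Collect activation statistics for post-training quantization calibration.
    parser.add_argument('--qe-calibration', action='store_true')
    # Quantize the model post-training and evaluate it.
    parser.add_argument('--post-qe-test', action='store_true')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--quantize-eval', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--pre-trained', action='store_true')
    parser.add_argument('--model-path', default='./model_save')
    args = parser.parse_args()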