示例#1
0
    def __init__(self,
                 model1: nn.Module,
                 model2: nn.Module,
                 method: str | Callable = 'pwcca',
                 model1_names: str | list[str] = None,
                 model2_names: str | list[str] = None,
                 model1_leaf_modules: list[nn.Module] = None,
                 model2_leaf_modules: list[nn.Module] = None,
                 train_mode: bool = False
                 ):
        """Prepare one feature extractor per model and select the distance function.

        Args:
            model1: First model to extract features from.
            model2: Second model to extract features from.
            method: Either a key into ``self._default_backends`` or a callable.
                A string is replaced by the looked-up entry; anything else is
                stored unchanged as ``self.distance_func``.
            model1_names: Forwarded to ``self.convert_names`` for ``model1``.
            model2_names: Forwarded to ``self.convert_names`` for ``model2``.
            model1_leaf_modules: Forwarded to ``self.convert_names`` for ``model1``.
            model2_leaf_modules: Forwarded to ``self.convert_names`` for ``model2``.
            train_mode: Forwarded to ``self.convert_names`` for both models.

        Raises:
            RuntimeWarning: If either model is an ``nn.DataParallel`` or
                ``nn.parallel.DistributedDataParallel`` instance. Note that
                this is raised as an exception, not emitted via ``warnings``.
        """
        parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
        if any(isinstance(candidate, parallel_wrappers) for candidate in (model1, model2)):
            raise RuntimeWarning('model is nn.DataParallel or nn.DistributedDataParallel. '
                                 'SimilarityHook may causes unexpected behavior.')

        # A string selects one of the built-in backends; a callable is used directly.
        # NOTE(review): an unknown string fails inside the lookup below (KeyError
        # if `_default_backends` is a dict -- confirm).
        self.distance_func = self._default_backends[method] if isinstance(method, str) else method
        self.model1 = model1
        self.model2 = model2

        # The result of `convert_names` is handed to `create_feature_extractor`
        # as its second positional argument.
        nodes_for_model1 = self.convert_names(model1, model1_names, model1_leaf_modules, train_mode)
        self.extractor1 = create_feature_extractor(model1, nodes_for_model1)
        nodes_for_model2 = self.convert_names(model2, model2_names, model2_leaf_modules, train_mode)
        self.extractor2 = create_feature_extractor(model2, nodes_for_model2)

        # Start out empty; nothing in this constructor fills them.
        self._model1_tensors: dict[str, torch.Tensor] = None
        self._model2_tensors: dict[str, torch.Tensor] = None
def _create_fx_model(model, train=False):
    """Wrap ``model`` in an FX feature extractor that returns its final output(s).

    In eval mode the extractor returns the last node reported by
    ``get_graph_node_names``. In train mode it returns the same by default, but
    when ``train`` is true the model is traced once more so that every node
    feeding the graph's final (output) node is returned, which covers models
    that produce several outputs while training.

    Args:
        model: Module to trace.
        train: When true, derive the train-mode return nodes from the inputs of
            the traced graph's last node instead of using only the last
            train-mode node name.

    Returns:
        The object returned by ``create_feature_extractor``.
    """
    # (`enable_cpatching=True` is left out of the tracer options, as in the
    # original commented-out entry.)
    tracer_kwargs = {
        'leaf_modules': list(_leaf_modules),
        'autowrap_functions': list(_autowrap_functions),
        'param_shapes_constant': True,
    }
    train_nodes, eval_nodes = get_graph_node_names(model, tracer_kwargs=tracer_kwargs)

    eval_return_nodes = [eval_nodes[-1]]
    train_return_nodes = [train_nodes[-1]]
    if train:
        # Trace once and walk the graph back to front: position 0 is then the
        # graph's last node, and each feeder's position doubles as its distance
        # from the end.
        traced_graph = NodePathTracer(**tracer_kwargs).trace(model)
        nodes_from_end = list(reversed(traced_graph.nodes))
        feeder_names = [feeder.name for feeder in nodes_from_end[0]._input_nodes.keys()]
        names_from_end = [node.name for node in nodes_from_end]
        # Negated distances are used as from-the-end indices into `train_nodes`.
        offsets_from_end = [-names_from_end.index(name) for name in feeder_names]
        train_return_nodes = [train_nodes[offset] for offset in offsets_from_end]

    return create_feature_extractor(
        model,
        train_return_nodes=train_return_nodes,
        eval_return_nodes=eval_return_nodes,
        tracer_kwargs=tracer_kwargs,
    )
 def _create_feature_extractor(self, *args, **kwargs):
     """
     Call ``create_feature_extractor`` with this object's leaf modules applied.

     All positional and keyword arguments are forwarded. If the caller supplies
     ``tracer_kwargs`` it is used as given; otherwise the tracer is configured
     with ``{"leaf_modules": self.leaf_modules}``. ``suppress_diff_warning=True``
     is always passed, so callers must not pass that keyword themselves.
     """
     if "tracer_kwargs" in kwargs:
         # Remove it from kwargs so it is not passed twice below.
         tracer_options = kwargs.pop("tracer_kwargs")
     else:
         tracer_options = {"leaf_modules": self.leaf_modules}
     return create_feature_extractor(
         *args,
         **kwargs,
         tracer_kwargs=tracer_options,
         suppress_diff_warning=True,
     )
示例#4
0
 def reset(self):
     """Build the feature-extraction model on first use and pick its device.

     Does nothing when ``self.model`` is already set. Otherwise:

     * reads the GPU id list from ``self.current_resource['gpu']``; the first
       id selects ``cuda:<id>``, and an empty or missing list falls back to
       ``'cpu'`` with a logged warning;
     * instantiates ``models.<self.model_name>(pretrained=True)``;
     * wraps it with ``create_feature_extractor`` so that the node named
       ``self.node_name`` is returned under the key ``'feature'``;
     * moves the result to ``self.device`` and switches it to eval mode.
     """
     if self.model is not None:
         return

     # The default must be a sized container: the previous default of `1`
     # made `len(gpu_ids)` raise TypeError whenever the 'gpu' key was absent,
     # so the CPU fallback below could never be reached in that case.
     gpu_ids = self.current_resource.get('gpu', [])
     if len(gpu_ids) >= 1:
         self.device = 'cuda:%d' % (gpu_ids[0])
         # NOTE(review): `fastest` is not among the documented
         # torch.backends.cudnn flags -- confirm it has any effect.
         cudnn.fastest = True
         cudnn.benchmark = True
     else:
         self.device = 'cpu'
         # NOTE(review): if `self.logger` is a stdlib logging.Logger, `warn`
         # is a deprecated alias of `warning`; confirm the logger type before
         # switching.
         self.logger.warn('No available GPUs, running on CPU.')

     base_model = getattr(models, self.model_name)(pretrained=True)
     extractor = create_feature_extractor(base_model,
                                          {self.node_name: 'feature'})
     self.model = extractor.to(self.device).eval()
示例#5
0
 def __init__(self, model, out_indices, out_map=None):
     """Build an FX graph module that returns selected feature nodes of ``model``.

     Args:
         model: Module to extract features from.
         out_indices: Positions within ``self.feature_info`` whose modules
             should be returned.
         out_map: Optional output names, indexed by the same position as
             ``self.feature_info``; when omitted, each output is keyed by its
             module name. Must have the same length as ``out_indices``.
     """
     super().__init__()
     assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
     self.feature_info = _get_feature_info(model, out_indices)
     if out_map is not None:
         assert len(out_map) == len(out_indices)

     # Map each selected module name to the key its output will be returned under.
     return_nodes = {}
     for position, info in enumerate(self.feature_info):
         if position not in out_indices:
             continue
         module_name = info['module']
         return_nodes[module_name] = module_name if out_map is None else out_map[position]

     tracer_options = {
         'leaf_modules': list(_leaf_modules),
         'autowrap_functions': list(_autowrap_functions),
     }
     self.graph_module = create_feature_extractor(
         model,
         return_nodes,
         tracer_kwargs=tracer_options)
示例#6
0
def create_model(num_classes):
    """Assemble a Faster R-CNN model on a single-feature-map VGG16-BN backbone.

    The backbone is an untrained ``vgg16_bn`` cut at node ``"features.42"``,
    whose activation is exposed under the key ``"0"``; that same key is the
    only feature map used for RoI pooling.

    Args:
        num_classes: Passed straight through to ``FasterRCNN``.

    Returns:
        The constructed ``FasterRCNN`` instance.
    """
    import torchvision
    from torchvision.models.feature_extraction import create_feature_extractor

    # VGG16-BN backbone. To inspect the extracted feature shape, run e.g.
    # backbone(torch.rand(1, 3, 224, 224))["0"].shape
    vgg = torchvision.models.vgg16_bn(pretrained=False)
    backbone = create_feature_extractor(vgg, return_nodes={"features.42": "0"})
    # Channel count of the "features.42" output
    # (presumably read by FasterRCNN -- confirm).
    backbone.out_channels = 512

    # Alternative backbones that were tried with the same recipe
    # (model -> return node, out_channels):
    #   torchvision.models.resnet50(pretrained=False)        -> "layer3",     1024
    #   torchvision.models.efficientnet_b0(pretrained=False) -> "features.5", 112

    rpn_anchors = AnchorsGenerator(
        sizes=((32, 64, 128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),),
    )

    box_pool = torchvision.ops.MultiScaleRoIAlign(
        featmap_names=['0'],   # feature maps on which RoIAlign pooling is performed
        output_size=[7, 7],    # spatial size of the pooled RoI feature
        sampling_ratio=2,      # sampling ratio
    )

    return FasterRCNN(
        backbone=backbone,
        num_classes=num_classes,
        rpn_anchor_generator=rpn_anchors,
        box_roi_pool=box_pool,
    )
            args.num_classes = 91

    else:
        kwargs = {}
        if args.num_classes and not segmentation:
            logging.info("Using num_classes = %d" % args.num_classes)
            kwargs["num_classes"] = args.num_classes

        model = model_classes[mname](pretrained=args.pretrained, progress=args.verbose, **kwargs)

        if args.extract != None:
            logging.info("Extract layers: " + ", ".join(args.extract))
            return_nodes = {
                layer: layer for layer in args.extract
            }
            model = create_feature_extractor(model, return_nodes=return_nodes)

        if segmentation and 'deeplabv3' in mname:
            model.classifier = DeepLabHead(2048, args.num_classes)
        
        if args.to_dd_native:
            # Make model NativeModuleWrapper compliant
            model = Wrapper(model)

        model.eval()

        # tracing or scripting model (default)
        if args.trace:
            example = get_image_input(args.batch_size, args.img_width, args.img_height) 
            script_module = torch.jit.trace(model, example)
        else: