def __init__(self, ARCH, nclasses, path=None, path_append="", strict=False):
    super().__init__()
    self.ARCH = ARCH
    self.nclasses = nclasses
    self.path = path
    self.path_append = path_append
    self.strict = False

    # get the model
    cur_dir = pathlib.Path(__file__).parent.absolute()
    spec = importlib.util.spec_from_file_location(
        "BackboneModule",
        cur_dir.joinpath("../../../backbones/" +
                         self.ARCH["backbone"]["name"] + ".py"))
    backbone_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(backbone_module)
    self.backbone = backbone_module.Backbone(params=self.ARCH["backbone"])

    # do a pass of the backbone to initialize the skip connections
    stub = torch.zeros((1,
                        self.backbone.get_input_depth(),
                        self.ARCH["dataset"]["sensor"]["img_prop"]["height"],
                        self.ARCH["dataset"]["sensor"]["img_prop"]["width"]))
    if torch.cuda.is_available():
        stub = stub.cuda()
        self.backbone.cuda()
    _, stub_skips = self.backbone(stub)

    decoder_spec = importlib.util.spec_from_file_location(
        "DecoderModule",
        cur_dir.joinpath("../decoders/" +
                         self.ARCH["decoder"]["name"] + ".py"))
    decoder_module = importlib.util.module_from_spec(decoder_spec)
    decoder_spec.loader.exec_module(decoder_module)
    self.decoder = decoder_module.Decoder(
        params=self.ARCH["decoder"],
        stub_skips=stub_skips,
        OS=self.ARCH["backbone"]["OS"],
        feature_depth=self.backbone.get_last_depth())

    self.head = nn.Sequential(
        nn.Dropout2d(p=ARCH["head"]["dropout"]),
        nn.Conv2d(self.decoder.get_last_depth(),
                  self.nclasses,
                  kernel_size=3,
                  stride=1,
                  padding=1))

    if self.ARCH["post"]["CRF"]["use"]:
        self.CRF = CRF(self.ARCH["post"]["CRF"]["params"], self.nclasses)
    else:
        self.CRF = None

    # train backbone?
    if not self.ARCH["backbone"]["train"]:
        for w in self.backbone.parameters():
            w.requires_grad = False

    # train decoder?
    if not self.ARCH["decoder"]["train"]:
        for w in self.decoder.parameters():
            w.requires_grad = False

    # train head?
    if not self.ARCH["head"]["train"]:
        for w in self.head.parameters():
            w.requires_grad = False

    # train CRF?
    if self.CRF and not self.ARCH["post"]["CRF"]["train"]:
        for w in self.CRF.parameters():
            w.requires_grad = False

    # print number of parameters and the ones requiring gradients
    weights_total = sum(p.numel() for p in self.parameters())
    weights_grad = sum(p.numel() for p in self.parameters() if p.requires_grad)
    print("Total number of parameters: ", weights_total)
    print("Total number of parameters requires_grad: ", weights_grad)

    # breakdown by layer
    weights_enc = sum(p.numel() for p in self.backbone.parameters())
    weights_dec = sum(p.numel() for p in self.decoder.parameters())
    weights_head = sum(p.numel() for p in self.head.parameters())
    print("Param encoder ", weights_enc)
    print("Param decoder ", weights_dec)
    print("Param head ", weights_head)
    if self.CRF:
        weights_crf = sum(p.numel() for p in self.CRF.parameters())
        print("Param CRF ", weights_crf)

    # get weights
    if path is not None:
        # try backbone
        try:
            w_dict = torch.load(path + "/backbone" + path_append,
                                map_location=lambda storage, loc: storage)
            self.backbone.load_state_dict(w_dict, strict=True)
            print("Successfully loaded model backbone weights")
        except Exception as e:
            print()
            print("Couldn't load backbone, using random weights. Error: ", e)
            if strict:
                print("I'm in strict mode and failure to load weights blows me up :)")
                raise e

        # try decoder
        try:
            w_dict = torch.load(path + "/segmentation_decoder" + path_append,
                                map_location=lambda storage, loc: storage)
            self.decoder.load_state_dict(w_dict, strict=True)
            print("Successfully loaded model decoder weights")
        except Exception as e:
            print("Couldn't load decoder, using random weights. Error: ", e)
            if strict:
                print("I'm in strict mode and failure to load weights blows me up :)")
                raise e

        # try head
        try:
            w_dict = torch.load(path + "/segmentation_head" + path_append,
                                map_location=lambda storage, loc: storage)
            self.head.load_state_dict(w_dict, strict=True)
            print("Successfully loaded model head weights")
        except Exception as e:
            print("Couldn't load head, using random weights. Error: ", e)
            if strict:
                print("I'm in strict mode and failure to load weights blows me up :)")
                raise e

        # try CRF
        if self.CRF:
            try:
                w_dict = torch.load(path + "/segmentation_CRF" + path_append,
                                    map_location=lambda storage, loc: storage)
                self.CRF.load_state_dict(w_dict, strict=True)
                print("Successfully loaded model CRF weights")
            except Exception as e:
                print("Couldn't load CRF, using random weights. Error: ", e)
                if strict:
                    print("I'm in strict mode and failure to load weights blows me up :)")
                    raise e
    else:
        print("No path to pretrained, using random init.")