def make_model(opts):
    """Build a segmentation model (backbone body + Deeplab head) from options.

    Selects the activated-normalization layer from ``opts.norm_act``,
    instantiates the backbone, optionally loads ImageNet-pretrained weights
    (minus the classification head), and attaches a DeeplabV2 or V3 head.
    Raises NotImplementedError for an unknown ``opts.deeplab`` value.
    """
    # Choose the activated-normalization layer requested by the options;
    # anything other than the in-place variants falls back to plain ABN.
    if opts.norm_act == 'iabn_sync':
        norm_layer = InPlaceABNSync
    elif opts.norm_act == 'iabn':
        norm_layer = InPlaceABN
    else:
        norm_layer = ABN
    norm = partial(norm_layer, activation="leaky_relu", activation_param=.01)

    body = models.__dict__[f'net_{opts.backbone}'](
        norm_act=norm, output_stride=opts.output_stride)

    if not opts.no_pretrained:
        # Load ImageNet weights; the classifier layer is dropped because the
        # segmentation model replaces it with its own head.
        pretrained_path = f'pretrained/{opts.backbone}_{opts.norm_act}.pth.tar'
        pre_dict = torch.load(pretrained_path, map_location='cpu')
        for key in ('classifier.fc.weight', 'classifier.fc.bias'):
            del pre_dict['state_dict'][key]
        body.load_state_dict(pre_dict['state_dict'])
        del pre_dict  # free memory

    head_channels = 256
    if opts.deeplab == 'v3':
        head = DeeplabV3(body.out_channels, head_channels, 256, norm_act=norm,
                         out_stride=opts.output_stride,
                         pooling_size=opts.pooling)
    elif opts.deeplab == 'v2':
        head = DeeplabV2(body.out_channels, head_channels, norm_act=norm,
                         out_stride=opts.output_stride)
    else:
        raise NotImplementedError("Specify a correct head.")

    return SegmentationModule(body, head, head_channels, opts.num_classes)
def load_snapshot(snapshot_file):
    """Load a training snapshot.

    Rebuilds the fixed architecture the snapshot was trained with
    (wider-ResNet38-a2 body + DeeplabV3 head with InPlaceABN) and restores
    the saved weights into it.

    Parameters
    ----------
    snapshot_file : str
        Path to a checkpoint whose ``state_dict`` contains ``body``,
        ``head`` and ``cls`` entries.

    Returns
    -------
    tuple
        ``(body, head, cls_state)`` — the body and head modules with
        recovered weights, plus the raw classifier state dict.
    """
    print("--- Loading model from snapshot")

    # Create network
    norm_act = partial(InPlaceABN, activation="leaky_relu", activation_param=.01)
    body = models.__dict__["net_wider_resnet38_a2"](norm_act=norm_act,
                                                    dilation=(1, 2, 4, 4))
    head = DeeplabV3(4096, 256, 256, norm_act=norm_act, pooling_size=(84, 84))

    # Load snapshot and recover network state.
    # map_location='cpu' lets a snapshot saved on GPU load on a CPU-only
    # machine (consistent with the pretrained-weight loading in make_model).
    data = torch.load(snapshot_file, map_location='cpu')
    body.load_state_dict(data["state_dict"]["body"])
    head.load_state_dict(data["state_dict"]["head"])

    return body, head, data["state_dict"]["cls"]
def make_model(opts, classes=None):
    """Build a segmentation model; incremental variant when `classes` is set.

    Selects a normalization layer from ``opts.norm_act``, builds the
    backbone (optionally loading ImageNet-pretrained weights without the
    classifier), attaches a DeeplabV3 head, and wraps everything in an
    ``IncrementalSegmentationModule`` when ``classes`` is provided,
    otherwise in a plain ``SegmentationModule``.
    """
    # Map the requested norm/activation to its activated-batch-norm class;
    # unknown values fall back to plain BatchNorm2d.
    activated_norms = {
        'iabn_sync': InPlaceABNSync,
        'iabn': InPlaceABN,
        'abn': ABN,
    }
    if opts.norm_act in activated_norms:
        norm = partial(activated_norms[opts.norm_act],
                       activation="leaky_relu", activation_param=.01)
    else:
        norm = nn.BatchNorm2d  # not synchronized, can be enabled with apex

    body = models.__dict__[f'net_{opts.backbone}'](
        norm_act=norm, output_stride=opts.output_stride)

    if not opts.no_pretrained:
        pretrained_path = f'pretrained/{opts.backbone}_{opts.norm_act}.pth.tar'
        pre_dict = torch.load(pretrained_path, map_location='cpu')
        # Strip the ImageNet classifier weights before loading the backbone.
        for key in ('classifier.fc.weight', 'classifier.fc.bias'):
            del pre_dict['state_dict'][key]
        body.load_state_dict(pre_dict['state_dict'])
        del pre_dict  # free memory

    head_channels = 256
    head = DeeplabV3(body.out_channels, head_channels, 256, norm_act=norm,
                     out_stride=opts.output_stride, pooling_size=opts.pooling)

    if classes is not None:
        # Incremental-learning setup: per-step classifier groups.
        return IncrementalSegmentationModule(body, head, head_channels,
                                             classes=classes,
                                             fusion_mode=opts.fusion_mode)
    return SegmentationModule(body, head, head_channels,
                              opts.num_classes, opts.fusion_mode)