def resnet18_domain(classes=4, domains=30, url=None, **kwargs):
    """Construct a domain-aware ResNet-18 initialised from pretrained ResNet-18 weights.

    A plain ``ResNet`` is built and loaded from ``url``. Its convolution
    weights and batch-norm statistics/affine parameters are then copied into
    a ``DomainResNet``:

    - Each batch-norm layer holds one set of running statistics per domain
      (``bns[i]``).
    - Each batch-norm layer holds per-domain affine parameters (row ``i`` of
      ``scale`` / ``bias``).

    Every domain starts from the same pretrained values. Finally the
    classifier is replaced by a fresh ``nn.Linear`` with ``classes`` outputs.

    Args:
        classes (int): number of output classes of the classification head.
        domains (int): number of domains included in the model
            (#source + #auxiliary).
        url (str): location of the pretrained ResNet-18 state dict, passed to
            ``model_zoo.load_url``. Required despite the ``None`` default.
        **kwargs: forwarded to both the ``ResNet`` and ``DomainResNet``
            constructors.

    Returns:
        The initialised ``DomainResNet``.

    Raises:
        ValueError: if ``url`` is None.
    """
    if url is None:
        # load_url(None) would otherwise fail deep inside the download code
        # with an unrelated-looking error.
        raise ValueError(
            "resnet18_domain requires `url` pointing to pretrained ResNet-18 weights"
        )

    # Instantiate the original ResNet and load the pretrained weights.
    model_origin = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    model_origin.load_state_dict(model_zoo.load_url(url))

    # Instantiate the domain ResNet with the same stage layout.
    model = DomainResNet(DomainBlock, [2, 2, 2, 2], domains=domains, **kwargs)

    # Index the target modules by qualified name once, instead of rescanning
    # model.named_modules() for every module of the source network.
    domain_modules = dict(model.named_modules())

    # Copy BN stats/params and conv weights from the original to the domain model.
    for name, m_orig in model_origin.named_modules():
        m_dom = domain_modules.get(name)
        if m_dom is None:
            # No same-named module in the domain model: nothing to copy.
            continue
        if 'bn' in name or 'downsample.1' in name:
            # 'downsample.1' is the BN of the shortcut branch.
            for i in range(domains):
                # Every domain gets a copy of the pretrained running statistics ...
                m_dom.bns[i].running_var.data[:] = m_orig.running_var.data[:]
                m_dom.bns[i].running_mean.data[:] = m_orig.running_mean.data[:]
                # ... and of the pretrained affine weight/bias (row i = domain i).
                m_dom.scale.data[i, :] = m_orig.weight.data[:]
                m_dom.bias.data[i, :] = m_orig.bias.data[:]
        elif 'conv' in name or 'downsample.0' in name:
            # Convolutions ('downsample.0' is the shortcut conv) are shared
            # across domains, so a single weight copy suffices.
            m_dom.weight.data[:] = m_orig.weight.data[:]

    # Replace the classifier with a fresh head for `classes` outputs.
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, classes)
    # Re-initialise the weights of every nn.Linear in the model (including the
    # new head) with a very small std.
    for m in model.modules():
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.0001)
    return model
def se_resnet152(num_classes=1_000):
    """Construct an SE-ResNet-152 model (ResNet-152 layout built from SE bottleneck blocks).

    Args:
        num_classes (int): number of outputs of the final classification layer.

    Returns:
        A ``ResNet`` built from ``SEBottleneck`` blocks with stage depths
        ``[3, 8, 36, 3]``. Its ``avgpool`` is replaced by
        ``nn.AdaptiveAvgPool2d(1)``, so the pooled feature map is 1x1
        regardless of the incoming spatial size. No pretrained weights are
        loaded.
    """
    model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes)
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


if __name__ == "__main__":
    import torch

    # Demo: build a model, list its layers, swap the classifier, run a forward pass.
    model = se_resnet50()

    # Print every sub-module (qualified name + repr) to see which layer to replace.
    for name, module in model.named_modules():
        print(f"layer name:{name}, layer instance:{module}")

    # Replace the network's classification layer with a new 102-way head.
    in_feat_num = model.fc.in_features
    model.fc = nn.Linear(in_feat_num, 102)

    # Forward a random image. This is inference only, so use eval() + no_grad():
    # this avoids updating any BatchNorm running stats and building an autograd graph.
    fake_img = torch.randn((1, 3, 224, 224))  # batch size * channels * height * width
    model.eval()
    with torch.no_grad():
        output = model(fake_img)
    print(output.shape)