示例#1
0
def resnet18_domain(classes=4, domains=30, url=None, **kwargs):
    """Construct a domain-based ResNet-18 initialised from pretrained weights.

    A plain ``ResNet`` (BasicBlock, [2, 2, 2, 2]) is built and loaded with the
    state dict found at ``url``. Its convolution weights and batch-norm
    statistics/affine parameters are then copied into a ``DomainResNet``.
    Every per-domain BN copy is initialised from the same pretrained BN
    module. Finally, the classifier ``fc`` is replaced by a new
    ``nn.Linear(in_features, classes)``.

    Args:
        classes (int): number of output classes of the new classifier.
        domains (int): number of domains included in the model
            (#source + #auxiliary).
        url (str): location of the pretrained ResNet-18 state dict, passed
            to ``model_zoo.load_url``. Required despite the ``None`` default.
        **kwargs: forwarded unchanged to both the ``ResNet`` and the
            ``DomainResNet`` constructors.

    Returns:
        The initialised ``DomainResNet`` instance.

    Raises:
        ValueError: if ``url`` is None.
    """
    if url is None:
        # load_url(None) fails obscurely deep inside the download/cache code.
        raise ValueError(
            "resnet18_domain requires `url` pointing to pretrained ResNet-18 weights"
        )

    # Instantiate the original ResNet and load the pretrained weights.
    model_origin = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    model_origin.load_state_dict(model_zoo.load_url(url))

    # Instantiate the domain ResNet.
    model = DomainResNet(DomainBlock, [2, 2, 2, 2], domains=domains, **kwargs)

    # Index the domain model's modules by qualified name once. This avoids
    # rescanning the whole module tree for every module of the original net.
    domain_modules = dict(model.named_modules())

    # Copy BN stats/params and conv weights from the original to the domain-based model.
    for name, m_orig in model_origin.named_modules():
        m_dom = domain_modules.get(name)
        if m_dom is None:
            # No same-named module in the domain model: nothing to copy.
            continue

        if 'bn' in name or 'downsample.1' in name:
            # Batch-norm module (the shortcut's 'downsample.1' is treated as BN
            # too). The domain module keeps one BN per domain in `bns` plus
            # per-domain affine rows in `scale` / `bias`; seed every domain
            # with the same pretrained values.
            for i in range(domains):
                m_dom.bns[i].running_var.data[:] = m_orig.running_var.data[:]
                m_dom.bns[i].running_mean.data[:] = m_orig.running_mean.data[:]
                m_dom.scale.data[i, :] = m_orig.weight.data[:]
                m_dom.bias.data[i, :] = m_orig.bias.data[:]

        elif 'conv' in name or 'downsample.0' in name:
            # Convolution (including the shortcut's 'downsample.0'): plain copy.
            m_dom.weight.data[:] = m_orig.weight.data[:]

    # Replace the classifier with a fresh layer producing `classes` outputs.
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, classes)
    # Re-initialise the weights of every Linear layer with a very small std;
    # biases are left untouched.
    for m in model.modules():
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.0001)

    return model
示例#2
0
def se_resnet152(num_classes=1_000):
    """Construct an SE-ResNet-152 model.

    The model is a ``ResNet`` built from ``SEBottleneck`` blocks with layer
    counts [3, 8, 36, 3]. No pretrained weights are loaded here.

    Args:
        num_classes (int): number of outputs of the final classifier.

    Returns:
        The constructed model.
    """
    model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes)
    # AdaptiveAvgPool2d(1) always pools to a 1x1 map, whatever the input
    # resolution, so the classifier input size stays fixed.
    model.avgpool = nn.AdaptiveAvgPool2d(1)
    return model


if __name__ == "__main__":

    import torch
    model = se_resnet50()

    # Inspect the network: print every named sub-module.
    for name, module in model.named_modules():
        print("layer name:{}, layer instance:{}".format(name, module))

    # Replace the classifier layer with a 102-way output layer.
    in_feat_num = model.fc.in_features
    model.fc = nn.Linear(in_feat_num, 102)

    # Forward pass as a shape smoke test. eval() puts layers such as
    # BatchNorm into inference mode (no running-stat updates), and no_grad()
    # avoids building an autograd graph that would never be used.
    model.eval()
    fake_img = torch.randn((1, 3, 224, 224))  # batch size * channels * height * width
    with torch.no_grad():
        output = model(fake_img)
    print(output.shape)