Ejemplo n.º 1
0
    def __init__(self, num_classes, layer='50', input_ch=3):
        """Build the backbone by borrowing layers from an ``extended_resnet`` model.

        Args:
            num_classes: Number of classes; stored on the instance as
                ``self.num_classes``.
            layer: ResNet depth given as a string. Supported values are
                ``'18'``, ``'50'``, ``'101'`` and ``'152'``.
            input_ch: Forwarded to the ``extended_resnet`` constructor as
                ``input_ch``.

        Raises:
            NotImplementedError: If ``layer`` is not one of the supported
                depths.
        """
        super(zzResBase, self).__init__()

        self.num_classes = num_classes
        print('resnet' + layer)

        # Every variant is requested with pretrained=True.
        if layer == '18':
            resnet = extended_resnet.resnet18(pretrained=True,
                                              input_ch=input_ch)
        elif layer == '50':
            resnet = extended_resnet.resnet50(pretrained=True,
                                              input_ch=input_ch)
        elif layer == '101':
            resnet = extended_resnet.resnet101(pretrained=True,
                                               input_ch=input_ch)
        elif layer == '152':
            resnet = extended_resnet.resnet152(pretrained=True,
                                               input_ch=input_ch)
        else:
            # Fail here with a clear message; without the explicit raise the
            # code would fall through and die on the unbound `resnet` below.
            raise NotImplementedError(
                "resnet{} is not supported".format(layer))

        # Stem of the network. Note the source's `bn1` is exposed as `bn0`.
        self.conv1 = resnet.conv1
        self.bn0 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        # The four residual stages, reused as-is.
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
Ejemplo n.º 2
0
    def __init__(self, num_classes, layer='50', input_ch=3):
        """Build the network: an ``extended_resnet`` encoder plus decoder modules.

        Args:
            num_classes: Number of classes; stored on the instance as
                ``self.num_classes``.
            layer: ResNet depth given as a string. Supported values are
                ``'18'``, ``'34'``, ``'50'``, ``'101'`` and ``'152'``.
            input_ch: Forwarded to the ``extended_resnet`` constructor as
                ``input_ch``.

        Raises:
            NotImplementedError: If ``layer`` is not one of the supported
                depths.
        """
        super(ResFCN, self).__init__()

        self.num_classes = num_classes
        print('resnet' + layer)

        # Every variant is requested with pretrained=True.
        if layer == '18':
            resnet = extended_resnet.resnet18(pretrained=True,
                                              input_ch=input_ch)
        elif layer == '34':
            resnet = extended_resnet.resnet34(pretrained=True,
                                              input_ch=input_ch)
        elif layer == '50':
            resnet = extended_resnet.resnet50(pretrained=True,
                                              input_ch=input_ch)
        elif layer == '101':
            resnet = extended_resnet.resnet101(pretrained=True,
                                               input_ch=input_ch)
        elif layer == '152':
            resnet = extended_resnet.resnet152(pretrained=True,
                                               input_ch=input_ch)
        else:
            # Fail here with a clear message; without the explicit raise the
            # code would fall through and die on the unbound `resnet` below.
            raise NotImplementedError(
                "resnet{} is not supported".format(layer))

        # Encoder stem. Note the source's `bn1` is exposed as `bn0`.
        self.conv1 = resnet.conv1
        self.bn0 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        # The four residual stages, reused as-is.
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Decoder modules. The two arguments are presumably (in, out) channel
        # counts -- TODO confirm against the Upsample definition.
        # NOTE(review): the hard-coded widths (2048/1024/512/256) look tied to
        # the bottleneck variants ('50'/'101'/'152'); verify that '18' and
        # '34' actually produce compatible feature maps.
        self.upsample1 = Upsample(2048, 1024)
        self.upsample2 = Upsample(1024, 512)
        self.upsample3 = Upsample(512, 64)
        self.upsample4 = Upsample(64, 64)
        self.upsample5 = Upsample(64, 32)

        # One Fusion module per decoder step.
        self.fs1 = Fusion(1024)
        self.fs2 = Fusion(512)
        self.fs3 = Fusion(256)
        self.fs4 = Fusion(64)
        self.fs5 = Fusion(64)

        # Output head built by the class's own `_classifier` helper.
        self.out5 = self._classifier(32)

        # 1x1 convolution mapping 256 channels down to 64.
        self.transformer = nn.Conv2d(256, 64, kernel_size=1)
Ejemplo n.º 3
0
    def __init__(self,
                 base_model='resnet50',
                 input_ch=3,
                 use_dropout_at_layer4=True):
        """Assemble the backbone from an ``extended_resnet`` model.

        Args:
            base_model: Name of the ResNet variant to build. One of
                ``'resnet18'``, ``'resnet50'``, ``'resnet101'`` or
                ``'resnet152'``.
            input_ch: Forwarded to the ``extended_resnet`` constructor as
                ``input_ch``.
            use_dropout_at_layer4: Forwarded to the ``extended_resnet``
                constructor under the same keyword.

        Raises:
            ValueError: If ``base_model`` is not one of the supported names.
        """
        super(ResBase, self).__init__()

        print(base_model)

        # Reject unknown variants up front.
        supported = ('resnet18', 'resnet50', 'resnet101', 'resnet152')
        if base_model not in supported:
            raise ValueError("{} is not supported".format(base_model))

        # The constructor in `extended_resnet` carries the same name as the
        # requested variant, so look it up directly.
        build_backbone = getattr(extended_resnet, base_model)
        backbone = build_backbone(
            pretrained=True,
            input_ch=input_ch,
            use_dropout_at_layer4=use_dropout_at_layer4)

        # Stem of the network. Note the source's `bn1` is exposed as `bn0`.
        self.conv1 = backbone.conv1
        self.bn0 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool

        # The four residual stages, attached in order under their own names.
        for stage in ('layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, stage, getattr(backbone, stage))