예제 #1
0
def get_upsample_layer(spatial_dims: int, in_channels: int, upsample_mode: str = "trilinear", scale_factor: int = 2):
    """Build a module that upsamples each spatial dimension by ``scale_factor``.

    Args:
        spatial_dims: number of spatial dimensions of the input. The interpolation
            path supports 1, 2 or 3.
        in_channels: number of input channels; only used by the ``"transpose"`` path.
        upsample_mode: ``"transpose"`` selects ``UpSample(..., with_conv=True)``.
            Any other value selects ``nn.Upsample`` with the linear interpolation
            mode that matches ``spatial_dims``. The string itself is otherwise
            ignored, as in the original implementation.
        scale_factor: multiplier applied to each spatial dimension.

    Returns:
        The upsampling ``nn.Module``.

    Raises:
        ValueError: if the interpolation path is used with ``spatial_dims``
            other than 1, 2 or 3.
    """
    if upsample_mode == "transpose":
        return UpSample(spatial_dims, in_channels, scale_factor=scale_factor, with_conv=True)

    # nn.Upsample's linear modes are dimension-specific:
    #   "linear"    -> (N, C, L)
    #   "bilinear"  -> (N, C, H, W)
    #   "trilinear" -> (N, C, D, H, W)
    # Picking by spatial_dims avoids handing "trilinear" to 1D input, which
    # would only fail later, at forward time.
    linear_modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
    try:
        interp_mode = linear_modes[spatial_dims]
    except KeyError:
        raise ValueError(
            f"spatial_dims must be 1, 2 or 3 for interpolation upsampling, got {spatial_dims}."
        ) from None
    return nn.Upsample(scale_factor=scale_factor, mode=interp_mode, align_corners=False)
def get_upsample_layer(
    spatial_dims: int, in_channels: int, upsample_mode: Union[UpsampleMode, str] = "nontrainable", scale_factor: int = 2
):
    """Create an ``UpSample`` layer whose output channel count equals its input channel count.

    Args:
        spatial_dims: number of spatial dimensions, forwarded to ``UpSample``.
        in_channels: input channel count; also passed as ``out_channels``.
        upsample_mode: forwarded to ``UpSample`` as ``mode``.
        scale_factor: forwarded to ``UpSample`` unchanged.

    Returns:
        The configured ``UpSample`` instance. Interpolation is requested with
        ``InterpolateMode.LINEAR`` and ``align_corners=False``.
    """
    # Collect the configuration first so the single constructor call stays compact;
    # out_channels deliberately mirrors in_channels.
    layer_options = dict(
        spatial_dims=spatial_dims,
        in_channels=in_channels,
        out_channels=in_channels,
        scale_factor=scale_factor,
        mode=upsample_mode,
        interp_mode=InterpolateMode.LINEAR,
        align_corners=False,
    )
    return UpSample(**layer_options)
예제 #3
0
    def __init__(self,
                 out_channels: int = 1,
                 upsample_mode: str = "bilinear",
                 pretrained: bool = True,
                 progress: bool = True):
        super(FCN, self).__init__()

        conv2d_type: Type[nn.Conv2d] = Conv[Conv.CONV, 2]

        self.upsample_mode = upsample_mode
        self.conv2d_type = conv2d_type
        self.out_channels = out_channels
        resnet = models.resnet50(pretrained=pretrained, progress=progress)

        self.conv1 = resnet.conv1
        self.bn0 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        self.gcn1 = GCN(2048, self.out_channels)
        self.gcn2 = GCN(1024, self.out_channels)
        self.gcn3 = GCN(512, self.out_channels)
        self.gcn4 = GCN(64, self.out_channels)
        self.gcn5 = GCN(64, self.out_channels)

        self.refine1 = Refine(self.out_channels)
        self.refine2 = Refine(self.out_channels)
        self.refine3 = Refine(self.out_channels)
        self.refine4 = Refine(self.out_channels)
        self.refine5 = Refine(self.out_channels)
        self.refine6 = Refine(self.out_channels)
        self.refine7 = Refine(self.out_channels)
        self.refine8 = Refine(self.out_channels)
        self.refine9 = Refine(self.out_channels)
        self.refine10 = Refine(self.out_channels)
        self.transformer = self.conv2d_type(in_channels=256,
                                            out_channels=64,
                                            kernel_size=1)

        if self.upsample_mode == "transpose":
            self.up_conv = UpSample(
                dimensions=2,
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                scale_factor=2,
                with_conv=True,
            )