def get_upsample_layer(spatial_dims: int, in_channels: int, upsample_mode: str = "trilinear", scale_factor: int = 2):
    """Build a module that enlarges a ``spatial_dims``-D feature map by ``scale_factor``.

    Args:
        spatial_dims: number of spatial dimensions of the input tensor.
        in_channels: number of input channels; only used by the ``"transpose"`` mode.
        upsample_mode: ``"transpose"`` builds an ``UpSample`` module with
            ``with_conv=True``. ``"nearest"`` and ``"area"`` are forwarded to
            ``nn.Upsample`` unchanged. Any other value selects linear interpolation
            matched to ``spatial_dims`` (``linear`` / ``bilinear`` / ``trilinear``),
            so passing e.g. ``"trilinear"`` for 2-D data still yields ``bilinear``.
        scale_factor: multiplier applied to each spatial size.

    Returns:
        The upsampling module.

    Raises:
        ValueError: if linear interpolation is requested for ``spatial_dims``
            outside 1-3. ``nn.Upsample`` has no linear mode for other ranks.
    """
    if upsample_mode == "transpose":
        return UpSample(spatial_dims, in_channels, scale_factor=scale_factor, with_conv=True)

    if upsample_mode in ("nearest", "area"):
        # align_corners is only valid for interpolating modes; PyTorch raises if it
        # is set for nearest/area, so leave it at the default (None).
        return nn.Upsample(scale_factor=scale_factor, mode=upsample_mode)

    # Previously 1-D inputs fell through to "trilinear", which fails at forward time.
    linear_modes = {1: "linear", 2: "bilinear", 3: "trilinear"}
    if spatial_dims not in linear_modes:
        raise ValueError(f"linear upsampling supports spatial_dims 1, 2 or 3, got {spatial_dims}.")
    return nn.Upsample(scale_factor=scale_factor, mode=linear_modes[spatial_dims], align_corners=False)
def get_upsample_layer(
    spatial_dims: int,
    in_channels: int,
    upsample_mode: Union[UpsampleMode, str] = "nontrainable",
    scale_factor: int = 2,
):
    """Return an ``UpSample`` module whose output channel count equals its input channel count.

    Args:
        spatial_dims: number of spatial dimensions, forwarded to ``UpSample``.
        in_channels: channel count used for both ``in_channels`` and ``out_channels``.
        upsample_mode: forwarded to ``UpSample`` as ``mode``.
        scale_factor: forwarded to ``UpSample`` as ``scale_factor``.

    Returns:
        The configured ``UpSample`` instance.
    """
    # Interpolation always uses InterpolateMode.LINEAR with align_corners=False;
    # presumably UpSample resolves LINEAR to the right rank from spatial_dims — confirm.
    upsample_kwargs = {
        "spatial_dims": spatial_dims,
        "in_channels": in_channels,
        "out_channels": in_channels,
        "scale_factor": scale_factor,
        "mode": upsample_mode,
        "interp_mode": InterpolateMode.LINEAR,
        "align_corners": False,
    }
    return UpSample(**upsample_kwargs)
def __init__(self, out_channels: int = 1, upsample_mode: str = "bilinear", pretrained: bool = True, progress: bool = True):
    """Assemble the network from a torchvision ResNet-50 backbone plus GCN/Refine heads.

    Args:
        out_channels: channel count produced by every GCN and Refine head.
        upsample_mode: stored on the instance; when it equals ``"transpose"`` an
            extra ``up_conv`` ``UpSample`` module is created.
        pretrained: forwarded to ``models.resnet50``.
        progress: forwarded to ``models.resnet50``.
    """
    super(FCN, self).__init__()
    conv_cls: Type[nn.Conv2d] = Conv[Conv.CONV, 2]
    self.upsample_mode = upsample_mode
    self.conv2d_type = conv_cls
    self.out_channels = out_channels

    # Reuse the ResNet-50 stem and residual stages; the stem batch-norm
    # (``bn1`` on the backbone) is stored here under the name ``bn0``.
    backbone = models.resnet50(pretrained=pretrained, progress=progress)
    self.conv1 = backbone.conv1
    self.bn0 = backbone.bn1
    self.relu = backbone.relu
    self.maxpool = backbone.maxpool
    self.layer1 = backbone.layer1
    self.layer2 = backbone.layer2
    self.layer3 = backbone.layer3
    self.layer4 = backbone.layer4

    # gcn1..gcn5 receive these input channel counts, registered in this order.
    for idx, head_in in enumerate((2048, 1024, 512, 64, 64), start=1):
        setattr(self, f"gcn{idx}", GCN(head_in, out_channels))
    # refine1..refine10 all operate on out_channels channels.
    for idx in range(1, 11):
        setattr(self, f"refine{idx}", Refine(out_channels))

    # 1x1 convolution mapping 256 -> 64 channels.
    self.transformer = conv_cls(in_channels=256, out_channels=64, kernel_size=1)

    # up_conv only exists for the "transpose" mode.
    if upsample_mode != "transpose":
        return
    self.up_conv = UpSample(
        dimensions=2,
        in_channels=out_channels,
        out_channels=out_channels,
        scale_factor=2,
        with_conv=True,
    )