def __init__(self, num_filters, time_dependent=False, non_linearity='relu'):
    """
    Convolutional function block used inside ConvODEUNet's ODE blocks.

    Args:
        num_filters (int): number of input/output channels of both 3x3 convs
        time_dependent (bool): whether to concat the time as a feature map
            before the convs (selects Conv2dTime instead of nn.Conv2d)
        non_linearity (str): which non_linearity to use
            (for options see get_nonlinearity)
    """
    super(ConvODEFunc, self).__init__()
    self.time_dependent = time_dependent
    # Counter of function evaluations (starts at zero here).
    self.nfe = 0
    self.norm = nn.InstanceNorm2d(num_filters)

    # Both conv variants are constructed with identical arguments; only the
    # class differs depending on whether time is fed in as an extra feature.
    conv_cls = Conv2dTime if time_dependent else nn.Conv2d
    conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
    self.conv1 = conv_cls(num_filters, num_filters, **conv_kwargs)
    self.conv2 = conv_cls(num_filters, num_filters, **conv_kwargs)

    self.non_linearity = get_nonlinearity(non_linearity)
def __init__(self, num_filters, output_dim=1, time_dependent=False,
             non_linearity='softplus', tol=1e-3, adjoint=False):
    """
    ConvODEUNet (U-Node in paper): a U-Net whose stages are ODE blocks.

    Args:
        num_filters (int): number of filters for the first stage; each deeper
            stage doubles it (up to 16x at the embedding)
        output_dim (int): how many feature maps the network outputs
        time_dependent (bool): whether to concat the time as a feature map
            before the convs
        non_linearity (str): which non_linearity to use
            (for options see get_nonlinearity)
        tol (float): tolerance to be used for ODE solver
        adjoint (bool): whether to use the adjoint method to calculate the
            gradients
    """
    super(ConvODEUNet, self).__init__()
    nf = num_filters

    def make_ode_block(channels):
        # Build the ConvODEFunc first, then wrap it, preserving the original
        # module-creation order (relevant for random parameter init).
        func = ConvODEFunc(channels, time_dependent, non_linearity)
        return ODEBlock(func, tol=tol, adjoint=adjoint)

    # Stem: 1x1 conv from 3 input channels to nf.
    self.input_1x1 = nn.Conv2d(3, nf, 1, 1)

    # Encoder: ODE stage, then a 1x1 conv that doubles the channel count.
    self.odeblock_down1 = make_ode_block(nf)
    self.conv_down1_2 = nn.Conv2d(nf, nf * 2, 1, 1)
    self.odeblock_down2 = make_ode_block(nf * 2)
    self.conv_down2_3 = nn.Conv2d(nf * 2, nf * 4, 1, 1)
    self.odeblock_down3 = make_ode_block(nf * 4)
    self.conv_down3_4 = nn.Conv2d(nf * 4, nf * 8, 1, 1)
    self.odeblock_down4 = make_ode_block(nf * 8)
    self.conv_down4_embed = nn.Conv2d(nf * 8, nf * 16, 1, 1)

    # Bottleneck.
    self.odeblock_embedding = make_ode_block(nf * 16)

    # Decoder: each 1x1 conv takes (deeper width + matching encoder width)
    # channels -- presumably a skip-connection concat done in forward
    # (not visible here; TODO confirm) -- and halves back down.
    self.conv_up_embed_1 = nn.Conv2d(nf * 16 + nf * 8, nf * 8, 1, 1)
    self.odeblock_up1 = make_ode_block(nf * 8)
    self.conv_up1_2 = nn.Conv2d(nf * 8 + nf * 4, nf * 4, 1, 1)
    self.odeblock_up2 = make_ode_block(nf * 4)
    self.conv_up2_3 = nn.Conv2d(nf * 4 + nf * 2, nf * 2, 1, 1)
    self.odeblock_up3 = make_ode_block(nf * 2)
    self.conv_up3_4 = nn.Conv2d(nf * 2 + nf, nf, 1, 1)
    self.odeblock_up4 = make_ode_block(nf)

    # Head: 1x1 conv to output_dim feature maps.
    self.classifier = nn.Conv2d(nf, output_dim, 1)
    self.non_linearity = get_nonlinearity(non_linearity)
def __init__(self, num_filters, non_linearity='relu'):
    """
    Convolutional block for ConvResUNet.

    Args:
        num_filters (int): number of input/output channels of both 3x3 convs
            and of the instance norm
        non_linearity (str): which non_linearity to use
            (for options see get_nonlinearity)
    """
    super(ConvResFunc, self).__init__()
    self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, stride=1, padding=1)
    # nn.InstanceNorm2d's signature is (num_features, eps=1e-5, ...). The old
    # call InstanceNorm2d(2, num_filters) set num_features=2 and
    # eps=num_filters, i.e. a wrong channel count plus an epsilon large enough
    # to badly distort the normalisation. Use num_filters as the feature count,
    # as ConvODEFunc does. With the default affine=False this module has no
    # parameters or buffers, so existing checkpoints still load.
    self.norm = nn.InstanceNorm2d(num_filters)
    self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=3, stride=1, padding=1)
    self.non_linearity = get_nonlinearity(non_linearity)
def __init__(self, num_filters, output_dim=1, non_linearity='softplus'):
    """
    ConvResUNet: a U-Net with the same layout as ConvODEUNet, but with a
    ConvResFunc block at every stage instead of an ODE block (this is not
    the U-Node model).

    Args:
        num_filters (int): number of filters for the first stage; each deeper
            stage doubles it (up to 16x at the embedding)
        output_dim (int): how many feature maps the network outputs
        non_linearity (str): which non_linearity to use
            (for options see get_nonlinearity)
    """
    super(ConvResUNet, self).__init__()
    nf = num_filters
    self.output_dim = output_dim

    # Stem: 1x1 conv from 3 input channels to nf.
    self.input_1x1 = nn.Conv2d(3, nf, 1, 1)

    # Encoder: block, then a 1x1 conv that doubles the channel count.
    self.block_down1 = ConvResFunc(nf, non_linearity)
    self.conv_down1_2 = nn.Conv2d(nf, nf * 2, 1, 1)
    self.block_down2 = ConvResFunc(nf * 2, non_linearity)
    self.conv_down2_3 = nn.Conv2d(nf * 2, nf * 4, 1, 1)
    self.block_down3 = ConvResFunc(nf * 4, non_linearity)
    self.conv_down3_4 = nn.Conv2d(nf * 4, nf * 8, 1, 1)
    self.block_down4 = ConvResFunc(nf * 8, non_linearity)
    self.conv_down4_embed = nn.Conv2d(nf * 8, nf * 16, 1, 1)

    # Bottleneck.
    self.block_embedding = ConvResFunc(nf * 16, non_linearity)

    # Decoder: each 1x1 conv takes (deeper width + matching encoder width)
    # channels -- presumably a skip-connection concat done in forward
    # (not visible here; TODO confirm).
    self.conv_up_embed_1 = nn.Conv2d(nf * 16 + nf * 8, nf * 8, 1, 1)
    self.block_up1 = ConvResFunc(nf * 8, non_linearity)
    self.conv_up1_2 = nn.Conv2d(nf * 8 + nf * 4, nf * 4, 1, 1)
    self.block_up2 = ConvResFunc(nf * 4, non_linearity)
    self.conv_up2_3 = nn.Conv2d(nf * 4 + nf * 2, nf * 2, 1, 1)
    self.block_up3 = ConvResFunc(nf * 2, non_linearity)
    self.conv_up3_4 = nn.Conv2d(nf * 2 + nf, nf, 1, 1)
    self.block_up4 = ConvResFunc(nf, non_linearity)

    # Head: 1x1 conv to output_dim feature maps.
    self.classifier = nn.Conv2d(nf, self.output_dim, 1)
    self.non_linearity = get_nonlinearity(non_linearity)