Example #1
0
    def __init__(self,
                 num_filters,
                 time_dependent=False,
                 non_linearity='relu'):
        """
        Set up the layers of the function block used by ConvODEUNet.

        Creates one instance norm, two 3x3 convolutions that keep the channel
        count at ``num_filters`` and the requested non-linearity.

        Args:
            num_filters (int): number of input and output channels of both convs
            time_dependent (bool): if true, the convs are ``Conv2dTime`` layers
                (time concatenated as a feature map), otherwise plain ``nn.Conv2d``
            non_linearity (str): name passed to ``get_nonlinearity``
        """
        super(ConvODEFunc, self).__init__()
        self.time_dependent = time_dependent
        self.nfe = 0  # Number of function evaluations

        self.norm = nn.InstanceNorm2d(num_filters)

        # Both convolutions share the same geometry; only the layer class
        # depends on whether time is fed in as an extra feature map.
        conv_cls = Conv2dTime if time_dependent else nn.Conv2d
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = conv_cls(num_filters, num_filters, **conv_kwargs)
        self.conv2 = conv_cls(num_filters, num_filters, **conv_kwargs)

        self.non_linearity = get_nonlinearity(non_linearity)
    def __init__(self, num_filters, output_dim=1, time_dependent=False,
                 non_linearity='softplus', tol=1e-3, adjoint=False):
        """
        Build the layers of ConvODEUNet (U-Node in paper).

        The network consists of a 1x1 input conv on 3 channels, four "down"
        ODE blocks, one embedding ODE block, four "up" ODE blocks and a 1x1
        classifier conv. Channel widths double at every down stage
        (``num_filters`` up to ``num_filters * 16``) and are reduced again by
        1x1 convs on the way up.

        Args:
            num_filters (int): number of filters for first conv layer
            output_dim (int): how many feature maps the network outputs
            time_dependent (bool): forwarded to every ConvODEFunc
            non_linearity (str): which non_linearity to use (for options see get_nonlinearity)
            tol (float): tolerance forwarded to every ODEBlock
            adjoint (bool): adjoint flag forwarded to every ODEBlock
        """
        super(ConvODEUNet, self).__init__()

        # Channel widths of the five resolution levels.
        c1 = num_filters
        c2 = num_filters * 2
        c4 = num_filters * 4
        c8 = num_filters * 8
        c16 = num_filters * 16

        def pointwise(in_channels, out_channels):
            # 1x1 convolution with stride 1, used to change the channel count.
            return nn.Conv2d(in_channels, out_channels, 1, 1)

        def ode_stage(channels):
            # ODE block wrapping a ConvODEFunc of the given width.
            func = ConvODEFunc(channels, time_dependent, non_linearity)
            return ODEBlock(func, tol=tol, adjoint=adjoint)

        # NOTE: layers are created in the same order as before so that
        # parameter initialisation and state_dict ordering stay unchanged.
        self.input_1x1 = pointwise(3, c1)

        self.odeblock_down1 = ode_stage(c1)
        self.conv_down1_2 = pointwise(c1, c2)

        self.odeblock_down2 = ode_stage(c2)
        self.conv_down2_3 = pointwise(c2, c4)

        self.odeblock_down3 = ode_stage(c4)
        self.conv_down3_4 = pointwise(c4, c8)

        self.odeblock_down4 = ode_stage(c8)
        self.conv_down4_embed = pointwise(c8, c16)

        self.odeblock_embedding = ode_stage(c16)

        # Each "up" conv takes the combined channel count of two levels
        # (presumably concatenated skip features -- forward() is not shown here).
        self.conv_up_embed_1 = pointwise(c16 + c8, c8)
        self.odeblock_up1 = ode_stage(c8)

        self.conv_up1_2 = pointwise(c8 + c4, c4)
        self.odeblock_up2 = ode_stage(c4)

        self.conv_up2_3 = pointwise(c4 + c2, c2)
        self.odeblock_up3 = ode_stage(c2)

        self.conv_up3_4 = pointwise(c2 + c1, c1)
        self.odeblock_up4 = ode_stage(c1)

        self.classifier = nn.Conv2d(c1, output_dim, 1)

        self.non_linearity = get_nonlinearity(non_linearity)
    def __init__(self, num_filters, non_linearity='relu'):
        """
        Block for ConvResUNet

        Creates two 3x3 convolutions that keep the channel count at
        ``num_filters``, one instance norm and the requested non-linearity.

        Args:
            num_filters (int): number of filters for the conv layers
            non_linearity (str): which non_linearity to use (for options see get_nonlinearity)
        """
        super(ConvResFunc, self).__init__()

        self.conv1 = nn.Conv2d(num_filters, num_filters, kernel_size=3, stride=1, padding=1)
        # The signature is InstanceNorm2d(num_features, eps=1e-5, ...). The
        # previous call InstanceNorm2d(2, num_filters) therefore set
        # num_features=2 and eps=num_filters, which all but disabled the
        # normalisation. With default settings the layer has no parameters,
        # so existing checkpoints still load; their outputs will differ
        # because the epsilon is now the default one.
        self.norm = nn.InstanceNorm2d(num_filters)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=3, stride=1, padding=1)

        self.non_linearity = get_nonlinearity(non_linearity)
Example #4
0
    def __init__(self, num_filters, output_dim=1, non_linearity='softplus'):
        """
        Build the layers of ConvResUNet.

        The network consists of a 1x1 input conv on 3 channels, four "down"
        ConvResFunc blocks, one embedding block, four "up" blocks and a 1x1
        classifier conv. Channel widths double at every down stage
        (``num_filters`` up to ``num_filters * 16``) and are reduced again by
        1x1 convs on the way up.

        NOTE(review): the previous docstring labelled this class "U-Node in
        paper"; that looks like a copy-paste label -- confirm the paper name.

        Args:
            num_filters (int): number of filters for first conv layer
            output_dim (int): how many feature maps the network outputs
            non_linearity (str): which non_linearity to use (for options see get_nonlinearity)
        """
        super(ConvResUNet, self).__init__()
        self.output_dim = output_dim

        # Channel widths of the five resolution levels.
        c1 = num_filters
        c2 = num_filters * 2
        c4 = num_filters * 4
        c8 = num_filters * 8
        c16 = num_filters * 16

        def pointwise(in_channels, out_channels):
            # 1x1 convolution with stride 1, used to change the channel count.
            return nn.Conv2d(in_channels, out_channels, 1, 1)

        def res_stage(channels):
            # Residual function block of the given width.
            return ConvResFunc(channels, non_linearity)

        # NOTE: layers are created in the same order as before so that
        # parameter initialisation and state_dict ordering stay unchanged.
        self.input_1x1 = pointwise(3, c1)

        self.block_down1 = res_stage(c1)
        self.conv_down1_2 = pointwise(c1, c2)
        self.block_down2 = res_stage(c2)
        self.conv_down2_3 = pointwise(c2, c4)
        self.block_down3 = res_stage(c4)
        self.conv_down3_4 = pointwise(c4, c8)
        self.block_down4 = res_stage(c8)
        self.conv_down4_embed = pointwise(c8, c16)

        self.block_embedding = res_stage(c16)

        # Each "up" conv takes the combined channel count of two levels
        # (presumably concatenated skip features -- forward() is not shown here).
        self.conv_up_embed_1 = pointwise(c16 + c8, c8)
        self.block_up1 = res_stage(c8)
        self.conv_up1_2 = pointwise(c8 + c4, c4)
        self.block_up2 = res_stage(c4)
        self.conv_up2_3 = pointwise(c4 + c2, c2)
        self.block_up3 = res_stage(c2)
        self.conv_up3_4 = pointwise(c2 + c1, c1)
        self.block_up4 = res_stage(c1)

        self.classifier = nn.Conv2d(c1, self.output_dim, 1)

        self.non_linearity = get_nonlinearity(non_linearity)