Example #1
    def __init__(self,
                 dim_in=1,
                 n_classes=8,
                 dim_latent=32,
                 num_res=4,
                 scale_factor=2,
                 activation='softmax'):
        super(UNOdeMSegNet, self).__init__()
        if activation == 'softmax':
            self.activation = nn.Softmax(dim=1)
        else:
            self.activation = None
        self.num_res = num_res

        # expand from image to higher-dim feature maps
        self.expand_cfn = nn.Sequential(
            nn.Conv2d(dim_in, dim_latent, kernel_size=3, stride=1, padding=1),
            actfn2nn(ActivationFn.RELU),
            nn.Conv2d(dim_latent,
                      dim_latent,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            actfn2nn(ActivationFn.RELU),
        )

        # downsample to coarser resolution
        self.downsample = MultiResDownsample(num_res, scale_factor)

        # dynamics for each resolution
        mres_cdfn = MultiResDMap(num_res, dim_latent, scale_factor)
        self.cdfn_block = ODEBlock(mres_cdfn, ODEBlockSpec(use_adjoint=True))

        # classify the multi-res feature
        self.classifier = MultiResMLP(num_res, dim_latent, n_classes)
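
ODEBlock and ODEBlockSpec are not defined in these excerpts. The use_adjoint=True flag suggests they wrap torchdiffeq's adjoint-mode integrator, but that is an assumption; ODEBlockSketch below is a hypothetical reconstruction of such a block, using torchdiffeq's (t, x) convention for the dynamics module.

import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint

class ODEBlockSketch(nn.Module):
    """Hypothetical reconstruction: integrate `dmap` over t in [0, 1]."""

    def __init__(self, dmap, use_adjoint=True):
        super().__init__()
        self.dmap = dmap  # nn.Module with forward(t, x) -> dx/dt
        self.odeint = odeint_adjoint if use_adjoint else odeint

    def forward(self, x):
        t = torch.tensor([0.0, 1.0], device=x.device)
        # the solver returns the state at both endpoints; keep the final one
        return self.odeint(self.dmap, x, t)[-1]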
Example #2
    def __init__(self, dim_in, dim_out):
        super(MSegNet, self).__init__()
        self.dim_out = dim_out

        # expand from image to higher-dim feature maps
        EXDIM = 64
        self.expand_cfn = nn.Sequential(
            nn.Conv2d(dim_in, EXDIM, kernel_size=3, stride=1, padding=1),
            actfn2nn(ActivationFn.RELU),
            nn.Conv2d(EXDIM, EXDIM, kernel_size=3, stride=1, padding=1),
            actfn2nn(ActivationFn.RELU),
        )

        # convolutional dynamics for the ODE state
        MIDDIM = 128
        cdfn = SequentialListDMap([
            ConcatSquashConv2d(dim_out + EXDIM,
                               MIDDIM,
                               actfn=ActivationFn.TANH),
            ConcatSquashConv2d(MIDDIM, MIDDIM, actfn=ActivationFn.TANH),
            ConcatSquashConv2d(MIDDIM,
                               dim_out + EXDIM,
                               actfn=ActivationFn.NONE),
        ])
        self.cdfn_block = ODEBlock(cdfn, ODEBlockSpec(use_adjoint=True))

        # log p(z) evaluated on only the first dim_out elements of the state
        self.logpdf = make_truncate_logpz(self.dim_out, dim_reduce=None)
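
make_truncate_logpz is also not shown. The comment indicates it scores only the first dim_out elements of the augmented state under a prior, leaving the EXDIM context features out of the density. A hypothetical sketch with a standard-normal prior (the dim_reduce handling is a guess):

import math
import torch

def make_truncate_logpz_sketch(d, dim_reduce=None):
    """Hypothetical: standard-normal log p(z) over the first d channels,
    ignoring the appended context features."""
    def logpz(z):
        zd = z[:, :d]
        logp = -0.5 * (zd ** 2 + math.log(2 * math.pi))
        if dim_reduce is None:
            return logp.flatten(1).sum(dim=1)  # one scalar per sample
        return logp.sum(dim=dim_reduce)
    return logpz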
Example #3
    def __init__(self, dim, map_block, actfn, ngroups=16):
        super(BasicResNetBlock, self).__init__()

        # group norm when ngroups > 1, otherwise an identity (no normalization)
        if ngroups > 1:
            self._norm1 = nn.GroupNorm(ngroups, dim, eps=1e-4)
        else:
            self._norm1 = util.actfn2nn(ActivationFn.NONE)
        self._relu1 = nn.ReLU()
        self._map1 = map_block
        self._actfn = util.actfn2nn(actfn)
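
The block's forward is not included in the excerpt; the naming (_norm1, _relu1, _map1) suggests a pre-activation residual block. A self-contained sketch of that pattern, under that assumption:

import torch
import torch.nn as nn

class BasicResNetBlockSketch(nn.Module):
    """Hypothetical pre-activation residual block: x + map(relu(norm(x)))."""

    def __init__(self, dim, map_block, ngroups=16):
        super().__init__()
        self.norm = (nn.GroupNorm(ngroups, dim, eps=1e-4)
                     if ngroups > 1 else nn.Identity())
        self.map = map_block

    def forward(self, x):
        return x + self.map(torch.relu(self.norm(x)))

block = BasicResNetBlockSketch(32, nn.Conv2d(32, 32, 3, padding=1))
out = block(torch.randn(2, 32, 16, 16))  # shape preserved for the skip add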
Example #4
    def __init__(self, spec):
        super(LinearDMap, self).__init__()

        # sanity check
        assert isinstance(spec, LinearSpec)

        # time dependence: t is appended as one extra input feature
        self.use_time = spec.use_time
        in_dim_add = 1 if self.use_time else 0

        # parse spec into linear layer
        nets = []
        nets.append(
            nn.Linear(spec.in_dim + in_dim_add, spec.out_dim, bias=spec.bias))
        nets.append(util.actfn2nn(spec.act_fn))
        self.net = nn.Sequential(*nets)
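
The forward pass for LinearDMap is not shown, but the in_dim + in_dim_add sizing implies that scalar time is concatenated to the input when use_time is set. A minimal runnable sketch of that convention:

import torch
import torch.nn as nn

net = nn.Linear(8 + 1, 8)  # in_dim + 1 extra feature for t
x = torch.randn(4, 8)
t = torch.tensor(0.5)
tt = t.expand(x.shape[0], 1)         # one copy of t per sample
dx = net(torch.cat([x, tt], dim=1))  # dx/dt handed to the ODE solver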
Example #5
    def __init__(self, spec):
        super(DynamicMap, self).__init__()

        # sanity check
        assert isinstance(spec, ConvSpec)

        # time dependence: t enters as one extra input channel
        self.use_time = spec.use_time
        in_channel_add = 1 if self.use_time else 0

        # parse spec into conv layer
        nets = []
        nets.append(
            nn.Conv2d(
                spec.in_channel + in_channel_add,
                spec.out_channel,
                kernel_size=spec.kernel_size,
                stride=spec.stride,
                padding=spec.padding,
            ))
        nets.append(util.actfn2nn(spec.act_fn))
        self.net = nn.Sequential(*nets)
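
DynamicMap applies the same idea in convolutional form: time enters as a constant-valued feature map, which is why the Conv2d is built with in_channel + in_channel_add inputs. Illustratively:

import torch

x = torch.randn(4, 8, 16, 16)
t = torch.tensor(0.5)
tt = t.expand(x.shape[0], 1, *x.shape[2:])  # (N, 1, H, W), every entry == t
h = torch.cat([x, tt], dim=1)               # (N, 9, H, W) fed to the conv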
Example #6
    def __init__(self, num_res, dim_in, dim_out, softmax_activation=False):
        super(MultiResMLP, self).__init__()

        # separate MLP (1x1 convolutions) for each resolution
        self.softmax_activation = softmax_activation
        MIDDIM = dim_in * 2
        self.mres_nets = nn.ModuleList()
        for _ in range(num_res):
            self.mres_nets.append(
                nn.Sequential(
                    nn.Conv2d(dim_in,
                              MIDDIM,
                              kernel_size=1,
                              stride=1,
                              padding=0),
                    actfn2nn(ActivationFn.RELU),
                    nn.Conv2d(MIDDIM,
                              dim_out,
                              kernel_size=1,
                              stride=1,
                              padding=0),
                ))
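
The 1x1 convolutions make each head a per-pixel MLP, one head per level of the feature pyramid. A hypothetical usage sketch (the shapes are illustrative, not from the source):

import torch
import torch.nn as nn

dim_in, n_classes, num_res = 32, 8, 3
heads = nn.ModuleList(
    nn.Sequential(nn.Conv2d(dim_in, dim_in * 2, kernel_size=1),
                  nn.ReLU(),
                  nn.Conv2d(dim_in * 2, n_classes, kernel_size=1))
    for _ in range(num_res))
feats = [torch.randn(2, dim_in, 64 // 2 ** i, 64 // 2 ** i)
         for i in range(num_res)]                     # one map per resolution
logits = [head(f) for head, f in zip(heads, feats)]   # per-pixel class scores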
Example #7
    def __init__(self,
                 dim_in,
                 dim_out,
                 ksize=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 bias=True,
                 transpose=False,
                 actfn=ActivationFn.NONE):
        super(ConcatSquashConv2d, self).__init__()
        # base (optionally transposed) convolution applied to the state
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(dim_in,
                             dim_out,
                             kernel_size=ksize,
                             stride=stride,
                             padding=padding,
                             dilation=dilation,
                             groups=groups,
                             bias=bias)
        # hypernetworks mapping scalar time t to a per-channel gate and bias
        self._hyper_gate = nn.Linear(1, dim_out)
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)
        self._actfn = util.actfn2nn(actfn)
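
The forward method is omitted from the excerpt. In FFJORD, where ConcatSquashConv2d originates, the convolution output is gated and shifted by the time-conditioned hypernetworks; a sketch of that forward, assuming t is a scalar tensor and that the stored activation is applied last:

import torch

def concat_squash_conv_forward(self, t, x):
    """Hypothetical forward: actfn(conv(x) * sigmoid(gate(t)) + bias(t)),
    with the per-channel gate/bias broadcast over H and W."""
    tt = t.view(1, 1)
    gate = torch.sigmoid(self._hyper_gate(tt)).view(1, -1, 1, 1)
    bias = self._hyper_bias(tt).view(1, -1, 1, 1)
    return self._actfn(self._layer(x) * gate + bias)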
Example #8
    def __init__(self, dim_in, dim_out, actfn):
        super(ConcatSquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        # hypernetworks mapping scalar time t to a per-feature gate and bias
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1, dim_out)
        self._actfn = util.actfn2nn(actfn)
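
ConcatSquashLinear is the dense counterpart: no spatial broadcast is needed, so the gating reduces to a single elementwise expression. A self-contained demo of the assumed computation:

import torch
import torch.nn as nn

layer = nn.Linear(8, 16)
gate = nn.Linear(1, 16)
bias = nn.Linear(1, 16, bias=False)
t, x = torch.tensor(0.3), torch.randn(4, 8)
# time-conditioned gate and shift applied to the linear output
out = layer(x) * torch.sigmoid(gate(t.view(1, 1))) + bias(t.view(1, 1))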