def __init__(self, dim_in=1, n_classes=8, dim_latent=32, num_res=4, scale_factor=2, activation='softmax'):
    """Assemble the multi-resolution ODE segmentation network.

    Args:
        dim_in: channel count of the input image (in_channels of the first
            Conv2d).
        n_classes: output width handed to the multi-resolution classifier.
        dim_latent: channel width of the expanded feature maps; also passed
            to the multi-resolution dynamics map and classifier.
        num_res: number of resolution levels, passed to every multi-res
            sub-module.
        scale_factor: forwarded to the downsampler and the dynamics map
            (presumably the ratio between consecutive resolutions -- confirm
            in MultiResDownsample).
        activation: only the exact string ``'softmax'`` installs
            ``nn.Softmax(1)`` as ``self.activation``; any other value leaves
            ``self.activation`` as ``None``.
    """
    super(UNOdeMSegNet, self).__init__()
    self.activation = nn.Softmax(1) if activation == 'softmax' else None
    self.num_res = num_res

    # Lift the image to dim_latent channels with two 3x3 conv + ReLU stages
    # (padding=1 with stride 1 keeps the spatial size unchanged).
    expand_layers = []
    for channels_in in (dim_in, dim_latent):
        expand_layers.append(
            nn.Conv2d(channels_in, dim_latent, kernel_size=3, stride=1, padding=1))
        expand_layers.append(actfn2nn(ActivationFn.RELU))
    self.expand_cfn = nn.Sequential(*expand_layers)

    # Produce the coarser resolution levels.
    self.downsample = MultiResDownsample(num_res, scale_factor)

    # Per-resolution dynamics, integrated by an ODE block using the adjoint.
    dynamics = MultiResDMap(num_res, dim_latent, scale_factor)
    self.cdfn_block = ODEBlock(dynamics, ODEBlockSpec(use_adjoint=True))

    # Classification head applied to the multi-resolution features.
    self.classifier = MultiResMLP(num_res, dim_latent, n_classes)
def __init__(self, dim_in, dim_out, *, dim_expand=64, dim_hidden=128):
    """Assemble the ODE segmentation network.

    Args:
        dim_in: channel count of the input image (in_channels of the first
            Conv2d).
        dim_out: number of leading state channels; stored as
            ``self.dim_out`` and used by the truncated log-density.
        dim_expand: channel width produced by the expansion conv stack
            (previously hard-coded as 64; the default preserves that).
        dim_hidden: hidden channel width of the concat-squash conv dynamics
            (previously hard-coded as 128; the default preserves that).

    The dynamics map takes and returns ``dim_out + dim_expand`` channels,
    so the ODE state carries both groups of channels together.
    """
    super(MSegNet, self).__init__()
    self.dim_out = dim_out

    # Expand from image to higher-dim feature maps: two 3x3 conv + ReLU
    # stages (padding=1 with stride 1 keeps the spatial size unchanged).
    self.expand_cfn = nn.Sequential(
        nn.Conv2d(dim_in, dim_expand, kernel_size=3, stride=1, padding=1),
        actfn2nn(ActivationFn.RELU),
        nn.Conv2d(dim_expand, dim_expand, kernel_size=3, stride=1, padding=1),
        actfn2nn(ActivationFn.RELU),
    )

    # Convolution dynamics: state -> hidden -> hidden -> state, with tanh
    # on the hidden layers and no activation on the output layer.
    state_dim = dim_out + dim_expand
    cdfn = SequentialListDMap([
        ConcatSquashConv2d(state_dim, dim_hidden, actfn=ActivationFn.TANH),
        ConcatSquashConv2d(dim_hidden, dim_hidden, actfn=ActivationFn.TANH),
        ConcatSquashConv2d(dim_hidden, state_dim, actfn=ActivationFn.NONE),
    ])
    self.cdfn_block = ODEBlock(cdfn, ODEBlockSpec(use_adjoint=True))

    # log p(z) using only the first dim_out elements of the state
    # (semantics of make_truncate_logpz not shown here -- confirm there).
    self.logpdf = make_truncate_logpz(self.dim_out, dim_reduce=None)
def __init__(self, dim, map_block, actfn, ngroups=16):
    """Create the layers of a basic ResNet-style block.

    Only construction is shown here; how the layers are combined (e.g. the
    residual connection) lives in ``forward``, which is not visible.

    Args:
        dim: channel count given to GroupNorm; unused when ``ngroups <= 1``.
        map_block: module stored as ``self._map1``.
        actfn: ActivationFn converted to a module via ``util.actfn2nn`` and
            stored as ``self._actfn``.
        ngroups: GroupNorm group count; a value of 1 or less disables group
            normalisation and uses ``util.actfn2nn(ActivationFn.NONE)``
            instead (presumably an identity module -- confirm in util).
    """
    super(BasicResNetBlock, self).__init__()
    self._norm1 = (
        nn.GroupNorm(ngroups, dim, eps=1e-4)
        if ngroups > 1
        else util.actfn2nn(ActivationFn.NONE)
    )
    self._relu1 = nn.ReLU()
    self._map1 = map_block
    self._actfn = util.actfn2nn(actfn)
def __init__(self, spec):
    """Build a linear layer followed by an activation from a ``LinearSpec``.

    Reads ``in_dim``, ``out_dim``, ``bias``, ``act_fn`` and ``use_time`` from
    *spec*. When ``spec.use_time`` is truthy, the linear layer accepts one
    extra input feature (presumably the time value concatenated in
    ``forward``, which is not shown).

    Raises:
        AssertionError: if *spec* is not a ``LinearSpec`` (sanity check only;
            skipped under ``python -O``).
    """
    super(LinearDMap, self).__init__()
    assert isinstance(spec, LinearSpec)

    self.use_time = spec.use_time
    extra_inputs = 1 if self.use_time else 0

    linear = nn.Linear(spec.in_dim + extra_inputs, spec.out_dim, bias=spec.bias)
    self.net = nn.Sequential(linear, util.actfn2nn(spec.act_fn))
def __init__(self, spec):
    """Build a 2-D convolution followed by an activation from a ``ConvSpec``.

    Reads ``in_channel``, ``out_channel``, ``kernel_size``, ``stride``,
    ``padding``, ``act_fn`` and ``use_time`` from *spec*. When
    ``spec.use_time`` is truthy, the convolution accepts one extra input
    channel (presumably a time channel concatenated in ``forward``, which is
    not shown).

    Raises:
        AssertionError: if *spec* is not a ``ConvSpec`` (sanity check only;
            skipped under ``python -O``).
    """
    super(DynamicMap, self).__init__()
    assert isinstance(spec, ConvSpec)

    self.use_time = spec.use_time
    extra_channels = 1 if self.use_time else 0

    # Convolution layer (not a linear layer) described by the spec.
    conv = nn.Conv2d(
        spec.in_channel + extra_channels,
        spec.out_channel,
        kernel_size=spec.kernel_size,
        stride=spec.stride,
        padding=spec.padding,
    )
    self.net = nn.Sequential(conv, util.actfn2nn(spec.act_fn))
def __init__(self, num_res, dim_in, dim_out, softmax_activation=False):
    """Create one independent per-pixel MLP for each resolution level.

    Each MLP is a 1x1 Conv2d (dim_in -> 2*dim_in), ReLU, then a 1x1 Conv2d
    (2*dim_in -> dim_out), so it acts pointwise over spatial positions.

    Args:
        num_res: number of resolution levels (one MLP each, in
            ``self.mres_nets``).
        dim_in: input channel count of every MLP.
        dim_out: output channel count of every MLP.
        softmax_activation: stored as ``self.softmax_activation``; its use
            is in ``forward``, which is not shown here.
    """
    super(MultiResMLP, self).__init__()
    self.softmax_activation = softmax_activation
    hidden = 2 * dim_in
    self.mres_nets = nn.ModuleList([
        nn.Sequential(
            nn.Conv2d(dim_in, hidden, kernel_size=1, stride=1, padding=0),
            actfn2nn(ActivationFn.RELU),
            nn.Conv2d(hidden, dim_out, kernel_size=1, stride=1, padding=0),
        )
        for _ in range(num_res)
    ])
def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=1, dilation=1, groups=1, bias=True, transpose=False, actfn=ActivationFn.NONE):
    """Create a (transposed) convolution plus scalar-conditioned gate and bias.

    Args:
        dim_in, dim_out: input/output channel counts of the convolution.
        ksize, stride, padding, dilation, groups, bias: forwarded to the
            convolution constructor (``ksize`` as ``kernel_size``).
        transpose: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
        actfn: ActivationFn converted via ``util.actfn2nn``.

    ``_hyper_gate`` and ``_hyper_bias`` map a single scalar to one value per
    output channel (presumably the ODE time t, as in concat-squash layers --
    confirm in ``forward``).
    """
    super(ConcatSquashConv2d, self).__init__()
    conv_kwargs = dict(
        kernel_size=ksize,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    if transpose:
        self._layer = nn.ConvTranspose2d(dim_in, dim_out, **conv_kwargs)
    else:
        self._layer = nn.Conv2d(dim_in, dim_out, **conv_kwargs)
    self._hyper_gate = nn.Linear(1, dim_out)
    self._hyper_bias = nn.Linear(1, dim_out, bias=False)
    self._actfn = util.actfn2nn(actfn)
def __init__(self, dim_in, dim_out, actfn):
    """Create a linear layer plus scalar-conditioned bias and gate.

    Args:
        dim_in, dim_out: input/output feature counts of the main linear layer.
        actfn: ActivationFn converted via ``util.actfn2nn``.

    ``_hyper_bias`` and ``_hyper_gate`` map a single scalar to one value per
    output feature (presumably the ODE time t -- confirm in ``forward``).
    """
    super(ConcatSquashLinear, self).__init__()
    self._layer = nn.Linear(dim_in, dim_out)
    # Bias is created/registered before the gate; keep this order, since
    # swapping it changes parameter registration and random-init order.
    self._hyper_bias = nn.Linear(1, dim_out, bias=False)
    self._hyper_gate = nn.Linear(1, dim_out)
    self._actfn = util.actfn2nn(actfn)