def build_model():
    activation_fn = ACTIVATION_FNS[args.act]
    dims = [data_dim] + list(map(int, args.dims.split('-'))) + [data_dim]
    blocks = []
    for _ in range(args.nblocks):
        # Each imBlock takes two independently parameterized nets, one for
        # each side of the implicit block.
        blocks.append(
            layers.imBlock(
                build_nnet(dims, activation_fn),
                # ACTIVATION_FNS['zero'](),
                build_nnet(dims, activation_fn),
                n_dist=args.n_dist,
                n_power_series=args.n_power_series,
                exact_trace=False,
                brute_force=args.brute_force,
                n_samples=args.n_samples,
                n_exact_terms=args.n_exact_terms,
                neumann_grad=False,
                grad_in_forward=False,  # toy data needn't save memory
                eps_forward=args.epsf))
    model = layers.SequentialFlow(blocks).to(device)
    return model
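
# Usage sketch, not part of the original file: build_model() reads the
# module-level `args`, `data_dim`, and `device`, so a caller only needs those
# in scope. Assuming the SequentialFlow convention used in this repo
# (forward(x, logpx) returns (z, delta_logp)), a call would look like:
#
#     model = build_model()
#     x = torch.randn(batch_size, data_dim).to(device)
#     zero = torch.zeros(x.shape[0], 1).to(x)
#     z, delta_logp = model(x, zero)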
def _resblock(initial_size, fc, idim=idim, first_resblock=True):
    if fc:
        # Fully-connected variant: both sides of the implicit block are FCNets.
        return layers.imBlock(
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                domains=domains,
                codomains=codomains,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
                learn_p=learn_p,
            ),
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                domains=domains,
                codomains=codomains,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
                learn_p=learn_p,
            ),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
    else:

        def build_nnet():
            ks = list(map(int, kernels.split('-')))
            if learn_p:
                # Learn the p-norm domains/codomains of the Lipschitz layers.
                _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                _codomains = _domains[1:] + [_domains[0]]
            else:
                _domains = domains
                _codomains = codomains
            nnet = []
            if not first_resblock and preact:
                if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                nnet.append(ACT_FNS[activation_fn](False))
            nnet.append(
                _lipschitz_layer(fc)(
                    initial_size[0], idim, ks[0], 1, ks[0] // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[0], codomain=_codomains[0],
                    atol=sn_atol, rtol=sn_rtol))
            if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
            nnet.append(ACT_FNS[activation_fn](True))
            for i, k in enumerate(ks[1:-1]):
                nnet.append(
                    _lipschitz_layer(fc)(
                        idim, idim, k, 1, k // 2,
                        coeff=coeff, n_iterations=n_lipschitz_iters,
                        domain=_domains[i + 1], codomain=_codomains[i + 1],
                        atol=sn_atol, rtol=sn_rtol))
                if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
                nnet.append(ACT_FNS[activation_fn](True))
            if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
            nnet.append(
                _lipschitz_layer(fc)(
                    idim, initial_size[0], ks[-1], 1, ks[-1] // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[-1], codomain=_codomains[-1],
                    atol=sn_atol, rtol=sn_rtol))
            if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            return nn.Sequential(*nnet)

        return layers.imBlock(
            build_nnet(),
            build_nnet(),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
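
# Usage sketch, not in the original: _resblock closes over idim, kernels,
# coeff, etc. from its enclosing model builder, and the fc flag selects
# between the two variants above. The concrete shapes are illustrative
# assumptions only:
#
#     conv_block = _resblock(initial_size=(64, 16, 16), fc=False)  # conv nets
#     fc_block = _resblock(initial_size=(64, 16, 16), fc=True)     # FCNets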
            ))
        if args.actnorm: blocks.append(layers.ActNorm1d(2))
        if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
    model = layers.SequentialFlow(blocks).to(device)
elif args.arch == 'implicit':
    dims = [2] + list(map(int, args.dims.split('-'))) + [2]
    blocks = []
    if args.actnorm: blocks.append(layers.ActNorm1d(2))
    for _ in range(args.nblocks):
        blocks.append(
            layers.imBlock(
                build_nnet(dims, activation_fn),
                build_nnet(dims, activation_fn),
                n_dist=args.n_dist,
                n_power_series=args.n_power_series,
                exact_trace=args.exact_trace,
                brute_force=args.brute_force,
                n_samples=args.n_samples,
                neumann_grad=False,
                grad_in_forward=False,  # toy data needn't save memory
            ))
    model = torch.nn.DataParallel(layers.SequentialFlow(blocks).to(device))
elif args.arch == 'realnvp':
    blocks = []
    for _ in range(args.nblocks):
        blocks.append(layers.CouplingBlock(2, swap=False))
        blocks.append(layers.CouplingBlock(2, swap=True))
        if args.actnorm: blocks.append(layers.ActNorm1d(2))
        if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
    model = layers.SequentialFlow(blocks).to(device)
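
# Sketch of the maximum-likelihood objective these toy models are trained
# with; this helper is an illustration, not part of the original file. It
# assumes the repo convention that model(x, zero) returns (z, delta_logp)
# with log p(x) = log N(z; 0, I) - delta_logp.
import math

def compute_loss_sketch(model, x):
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp = model(x, zero)
    # Standard-normal log-density of z, summed over dimensions.
    logpz = (-0.5 * math.log(2 * math.pi) - 0.5 * z.pow(2)).sum(1, keepdim=True)
    # Change of variables: negative mean log-likelihood.
    return -torch.mean(logpz - delta_logp)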
def __init__(
    self,
    in_planes,
    hidden,
    planes,
    stride=1,
    n_lipschitz_iters=None,
    sn_atol=1e-3,
    sn_rtol=1e-3,
):
    super(BasicImplicitBlock, self).__init__()
    coeff = args.coeff
    self.initialized = False

    def build_net():
        # 3x3 conv -> ReLU -> 3x3 conv -> ReLU, each conv spectrally
        # normalized to Lipschitz constant `coeff`.
        layer = base_layers.get_conv2d
        nnet = []
        nnet.append(
            layer(
                in_planes,
                hidden,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
                coeff=coeff,
                n_iterations=n_lipschitz_iters,
                domain=2,
                codomain=2,
                atol=sn_atol,
                rtol=sn_rtol,
            ))
        nnet.append(ACTIVATION_FNS['relu']())
        nnet.append(
            layer(
                hidden,
                in_planes,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
                coeff=coeff,
                n_iterations=n_lipschitz_iters,
                domain=2,
                codomain=2,
                atol=sn_atol,
                rtol=sn_rtol,
            ))
        nnet.append(ACTIVATION_FNS['relu']())
        return nn.Sequential(*nnet)

    # Two independently parameterized nets, one per side of the implicit block.
    self.block = layers.imBlock(
        build_net(),
        build_net(),
    )

    # Shortcut path in the usual ResNet convention, used when the output
    # shape differs from the input shape.
    self.downsample = nn.Sequential()
    if stride != 1 or in_planes != self.expansion * planes:
        self.downsample = nn.Sequential(
            nn.Conv2d(in_planes,
                      self.expansion * planes,
                      kernel_size=1,
                      stride=stride,
                      bias=False),
            nn.BatchNorm2d(self.expansion * planes),
            ACTIVATION_FNS['relu'](),
        )
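
# Construction sketch, not part of the original: `expansion` follows the
# torchvision-style BasicBlock convention (a class attribute, typically 1),
# and args.coeff must be set before instantiation. All concrete sizes below
# are illustrative assumptions.
#
#     block = BasicImplicitBlock(in_planes=64, hidden=128, planes=64)
#     y = block.block(torch.randn(8, 64, 32, 32))  # assumed: imBlock returns
#                                                  # y when no logpx is given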