def _resblock(initial_size, fc, idim=idim, first_resblock=False):
    if fc:
        return layers.iResBlock(
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                domains=domains,
                codomains=codomains,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
                learn_p=learn_p,
            ),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
    else:
        ks = list(map(int, kernels.split('-')))
        if learn_p:
            _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
            _codomains = _domains[1:] + [_domains[0]]
        else:
            _domains = domains
            _codomains = codomains
        nnet = []
        if not first_resblock and preact:
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            nnet.append(ACT_FNS[activation_fn](False))
        nnet.append(
            _lipschitz_layer(fc)(
                initial_size[0], idim, ks[0], 1, ks[0] // 2,
                coeff=coeff, n_iterations=n_lipschitz_iters,
                domain=_domains[0], codomain=_codomains[0],
                atol=sn_atol, rtol=sn_rtol,
            )
        )
        if batchnorm:
            nnet.append(layers.MovingBatchNorm2d(idim))
        nnet.append(ACT_FNS[activation_fn](True))
        for i, k in enumerate(ks[1:-1]):
            nnet.append(
                _lipschitz_layer(fc)(
                    idim, idim, k, 1, k // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[i + 1], codomain=_codomains[i + 1],
                    atol=sn_atol, rtol=sn_rtol,
                )
            )
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(idim))
            nnet.append(ACT_FNS[activation_fn](True))
        if dropout:
            nnet.append(nn.Dropout2d(dropout, inplace=True))
        nnet.append(
            _lipschitz_layer(fc)(
                idim, initial_size[0], ks[-1], 1, ks[-1] // 2,
                coeff=coeff, n_iterations=n_lipschitz_iters,
                domain=_domains[-1], codomain=_codomains[-1],
                atol=sn_atol, rtol=sn_rtol,
            )
        )
        if batchnorm:
            nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
        return layers.iResBlock(
            nn.Sequential(*nnet),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
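
# A minimal sketch (an illustrative assumption, not layers.iResBlock itself) of the
# biased, fixed-truncation log-determinant estimator that the n_power_series argument
# above controls. It uses the power series
#     log det(I + J_g(x)) = sum_{k>=1} (-1)^{k+1} / k * tr(J_g^k),
# with each trace estimated by Hutchinson's trick, tr(A) ~ E_v[v^T A v]. The names
# logdet_estimate and g are hypothetical; the unbiased Russian-roulette variant
# (n_dist / n_exact_terms) is omitted.
import torch


def logdet_estimate(g, x, n_power_series=5):
    x = x.requires_grad_(True)
    gx = g(x)
    v = torch.randn_like(x)
    vjp = v
    logdet = torch.zeros(x.shape[0], device=x.device)
    for k in range(1, n_power_series + 1):
        # One vector-Jacobian product per term: vjp <- vjp @ J_g.
        vjp = torch.autograd.grad(gx, x, vjp, create_graph=True)[0]
        trace_est = (vjp * v).flatten(1).sum(1)  # ~ tr(J_g^k), per sample
        logdet = logdet + (-1) ** (k + 1) / k * trace_est
    return logdet
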
def _resblock(initial_size, fc, idim=idim, first_resblock=False, densenet=densenet,
              densenet_depth=densenet_depth, densenet_growth=densenet_growth,
              fc_densenet_growth=fc_densenet_growth, learnable_concat=learnable_concat,
              lip_coeff=lip_coeff):
    if fc:
        return layers.iResBlock(
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                domains=domains,
                codomains=codomains,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
                learn_p=learn_p,
                densenet=densenet,
                densenet_depth=densenet_depth,
                densenet_growth=densenet_growth,
                fc_densenet_growth=fc_densenet_growth,
                learnable_concat=learnable_concat,
                lip_coeff=lip_coeff,
            ),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
    else:
        if densenet:
            ks = list(map(int, kernels.split('-')))
            if learn_p:
                _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                _codomains = _domains[1:] + [_domains[0]]
            else:
                _domains = domains
                _codomains = codomains

            nnet = []
            total_in_channels = initial_size[0]
            for i in range(densenet_depth):
                part_net = []
                # CLipSwish concatenates [x, -x] and doubles the channel count,
                # so the convolution only needs to produce half the growth.
                if activation_fn == 'CLipSwish':
                    output_channels = densenet_growth // 2
                else:
                    output_channels = densenet_growth
                part_net.append(
                    _lipschitz_layer(fc)(
                        total_in_channels, output_channels, 3, 1, padding=1,
                        coeff=coeff, n_iterations=n_lipschitz_iters,
                        domain=_domains[0], codomain=_codomains[0],
                        atol=sn_atol, rtol=sn_rtol,
                    )
                )
                if batchnorm:
                    part_net.append(layers.MovingBatchNorm2d(output_channels))
                part_net.append(ACT_FNS[activation_fn](False))
                nnet.append(
                    layers.LipschitzDenseLayer(
                        nn.Sequential(*part_net),
                        learnable_concat,
                        lip_coeff,
                    )
                )
                total_in_channels += densenet_growth

            # Final 1x1 convolution maps the accumulated channels back to the input size.
            nnet.append(
                _lipschitz_layer(fc)(
                    total_in_channels, initial_size[0], 1, 1, padding=0,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[0], codomain=_codomains[0],
                    atol=sn_atol, rtol=sn_rtol,
                )
            )
            return layers.iResBlock(
                nn.Sequential(*nnet),
                n_power_series=n_power_series,
                n_dist=n_dist,
                n_samples=n_samples,
                n_exact_terms=n_exact_terms,
                neumann_grad=neumann_grad,
                grad_in_forward=grad_in_forward,
            )
        # ResNet (non-densenet) variant
        else:
            ks = list(map(int, kernels.split('-')))
            if learn_p:
                _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                _codomains = _domains[1:] + [_domains[0]]
            else:
                _domains = domains
                _codomains = codomains

            nnet = []
            # CLipSwish doubles channels, so hidden widths are halved and a
            # pre-activated input arrives with twice the channels.
            if activation_fn == 'CLipSwish':
                idim_out = idim // 2
                if not first_resblock and preact:
                    in_size = initial_size[0] * 2
                else:
                    in_size = initial_size[0]
            else:
                idim_out = idim
                in_size = initial_size[0]

            if not first_resblock and preact:
                if batchnorm:
                    nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                nnet.append(ACT_FNS[activation_fn](False))
            nnet.append(
                _lipschitz_layer(fc)(
                    in_size, idim_out, ks[0], 1, ks[0] // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[0], codomain=_codomains[0],
                    atol=sn_atol, rtol=sn_rtol,
                )
            )
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(idim_out))
            nnet.append(ACT_FNS[activation_fn](True))
            for i, k in enumerate(ks[1:-1]):
                nnet.append(
                    _lipschitz_layer(fc)(
                        idim, idim_out, k, 1, k // 2,
                        coeff=coeff, n_iterations=n_lipschitz_iters,
                        domain=_domains[i + 1], codomain=_codomains[i + 1],
                        atol=sn_atol, rtol=sn_rtol,
                    )
                )
                if batchnorm:
                    nnet.append(layers.MovingBatchNorm2d(idim_out))
                nnet.append(ACT_FNS[activation_fn](True))
            if dropout:
                nnet.append(nn.Dropout2d(dropout, inplace=True))
            nnet.append(
                _lipschitz_layer(fc)(
                    idim, initial_size[0], ks[-1], 1, ks[-1] // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[-1], codomain=_codomains[-1],
                    atol=sn_atol, rtol=sn_rtol,
                )
            )
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            return layers.iResBlock(
                nn.Sequential(*nnet),
                n_power_series=n_power_series,
                n_dist=n_dist,
                n_samples=n_samples,
                n_exact_terms=n_exact_terms,
                neumann_grad=neumann_grad,
                grad_in_forward=grad_in_forward,
            )
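
# A minimal sketch (an assumption; not necessarily the repo's exact
# layers.LipschitzDenseLayer) of the learnable-concatenation dense layer used in the
# densenet branch above: the input and the convolutional branch are concatenated with
# scalars (eta1, eta2) normalized so that sqrt(eta1^2 + eta2^2) = lip_coeff. If the
# branch is 1-Lipschitz, the concatenated map then has Lipschitz constant <= lip_coeff.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LipschitzDenseLayerSketch(nn.Module):
    def __init__(self, branch, learnable_concat=True, lip_coeff=0.98):
        super().__init__()
        self.branch = branch
        self.lip_coeff = lip_coeff
        eta = torch.ones(2)
        if learnable_concat:
            self.eta_raw = nn.Parameter(eta)
        else:
            self.register_buffer('eta_raw', eta)

    def forward(self, x):
        # Softplus keeps both weights positive; rescaling pins their 2-norm.
        eta = F.softplus(self.eta_raw)
        eta = eta / eta.norm() * self.lip_coeff
        # Channel-wise concat of the scaled identity and branch outputs.
        return torch.cat([eta[0] * x, eta[1] * self.branch(x)], dim=1)
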
if __name__ == '__main__':
    activation_fn = ACTIVATION_FNS[args.act]

    if args.arch == 'iresnet':
        dims = [2] + list(map(int, args.dims.split('-'))) + [2]
        blocks = []
        if args.actnorm:
            blocks.append(layers.ActNorm1d(2))
        for _ in range(args.nblocks):
            blocks.append(
                layers.iResBlock(
                    build_nnet(dims, activation_fn),
                    n_dist=args.n_dist,
                    n_power_series=args.n_power_series,
                    exact_trace=args.exact_trace,
                    brute_force=args.brute_force,
                    n_samples=args.n_samples,
                    neumann_grad=False,
                    grad_in_forward=False,
                )
            )
            if args.actnorm:
                blocks.append(layers.ActNorm1d(2))
            if args.batchnorm:
                blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)
    elif args.arch == 'realnvp':
        blocks = []
        for _ in range(args.nblocks):
            blocks.append(layers.CouplingLayer(2, swap=False))
            blocks.append(layers.CouplingLayer(2, swap=True))
            if args.actnorm:
                blocks.append(layers.ActNorm1d(2))
            if args.batchnorm:
                blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)
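
# A plausible sketch of the build_nnet helper the toy script calls above (the actual
# definition lives elsewhere in the script and uses the repo's own spectral-norm
# layers with a contraction coefficient < 1; torch's built-in spectral_norm only
# normalizes to ~1, so this is an approximation). build_nnet_sketch is a hypothetical
# name.
import torch.nn as nn


def build_nnet_sketch(dims, activation_fn=nn.ReLU):
    nnet = []
    for in_dim, out_dim in zip(dims[:-1], dims[1:]):
        # Spectral normalization keeps each linear map roughly 1-Lipschitz.
        nnet.append(nn.utils.spectral_norm(nn.Linear(in_dim, out_dim)))
        nnet.append(activation_fn())
    # Drop the trailing activation so the residual branch ends with a linear map.
    return nn.Sequential(*nnet[:-1])
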
def get_iresblock(initial_size, fc, idim, first_resblock=False, n_power_series=5,
                  kernels='3-1-3', batchnorm=True, preact=False, activation_fn='relu',
                  n_lipschitz_iters=2000, coeff=0.97, sn_atol=1e-3, sn_rtol=1e-3,
                  dropout=0., n_dist='geometric', n_samples=1, n_exact_terms=0,
                  neumann_grad=True, grad_in_forward=False):
    if fc:
        return layers.iResBlock(
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
            ),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
    else:
        ks = list(map(int, kernels.split('-')))  # kernel sizes; [3, 1, 3] by default
        nnet = []
        # Residual-branch architecture:
        #   (batchnorm) conv1d (batchnorm) act
        #   conv1d (batchnorm) act
        #   (dropout) conv1d (batchnorm)
        if not first_resblock and preact:
            if batchnorm:
                nnet.append(layers.MovingBatchNormconv1d(initial_size[0]))
            nnet.append(ACT_FNS[activation_fn](False))
        nnet.append(
            _lipschitz_layer(fc)(
                initial_size[0], idim, ks[0], 1, ks[0] // 2,
                coeff=coeff, n_iterations=n_lipschitz_iters,
                atol=sn_atol, rtol=sn_rtol,
            )
        )
        if batchnorm:
            nnet.append(layers.MovingBatchNormconv1d(idim))
        nnet.append(ACT_FNS[activation_fn](True))
        for i, k in enumerate(ks[1:-1]):
            nnet.append(
                _lipschitz_layer(fc)(
                    idim, idim, k, 1, k // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    atol=sn_atol, rtol=sn_rtol,
                )
            )
            if batchnorm:
                nnet.append(layers.MovingBatchNormconv1d(idim))
            nnet.append(ACT_FNS[activation_fn](True))
        if dropout:
            # Channel-wise dropout for 1-D feature maps.
            nnet.append(nn.Dropout1d(dropout, inplace=True))
        nnet.append(
            _lipschitz_layer(fc)(
                idim, initial_size[0], ks[-1], 1, ks[-1] // 2,
                coeff=coeff, n_iterations=n_lipschitz_iters,
                atol=sn_atol, rtol=sn_rtol,
            )
        )
        if batchnorm:
            nnet.append(layers.MovingBatchNormconv1d(initial_size[0]))
        return layers.iResBlock(
            nn.Sequential(*nnet),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
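
# A hypothetical usage example of get_iresblock (an assumption: shapes are
# illustrative, and the module-level names used above, e.g. layers, ACT_FNS and
# _lipschitz_layer, must be in scope). It builds one convolutional block for
# single-channel sequences and runs a forward pass with log-density accounting.
import torch

block = get_iresblock(initial_size=(1, 128), fc=False, idim=64)
x = torch.randn(8, 1, 128)        # (batch, channels, length)
logpx = torch.zeros(8, 1)
y, logpx = block(x, logpx)        # iResBlock subtracts its log-det term from logpx
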