def _resblock(initial_size, fc, idim=idim, first_resblock=False):
    if fc:
        nonloc_scope.swap = not nonloc_scope.swap
        return layers.CouplingBlock(
            initial_size[0],
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_weight_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                coeff=None,
                domains=None,
                codomains=None,
                n_iterations=None,
                sn_atol=None,
                sn_rtol=None,
                learn_p=None,
                div_in=2,
            ),
            swap=nonloc_scope.swap,
        )
    else:
        ks = list(map(int, kernels.split('-')))
        # Channel-wise coupling splits the input in half along channels; masked
        # (checkerboard) coupling keeps the full input and doubles the output instead.
        if init_layer is None:
            _block = layers.ChannelCouplingBlock
            _mask_type = 'channel'
            div_in = 2
            mult_out = 1
        else:
            _block = layers.MaskedCouplingBlock
            _mask_type = 'checkerboard'
            div_in = 1
            mult_out = 2
        # Alternate the mask parity between consecutive blocks.
        nonloc_scope.swap = not nonloc_scope.swap
        _mask_type += '1' if nonloc_scope.swap else '0'

        nnet = []
        if not first_resblock and preact:
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            nnet.append(ACT_FNS[activation_fn](False))
        nnet.append(_weight_layer(fc)(initial_size[0] // div_in, idim, ks[0], 1, ks[0] // 2))
        if batchnorm:
            nnet.append(layers.MovingBatchNorm2d(idim))
        nnet.append(ACT_FNS[activation_fn](True))
        for i, k in enumerate(ks[1:-1]):
            nnet.append(_weight_layer(fc)(idim, idim, k, 1, k // 2))
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(idim))
            nnet.append(ACT_FNS[activation_fn](True))
        if dropout:
            nnet.append(nn.Dropout2d(dropout, inplace=True))
        nnet.append(_weight_layer(fc)(idim, initial_size[0] * mult_out, ks[-1], 1, ks[-1] // 2))
        if batchnorm:
            # The last layer emits initial_size[0] * mult_out channels, so normalize that many.
            nnet.append(layers.MovingBatchNorm2d(initial_size[0] * mult_out))
        return _block(initial_size[0], nn.Sequential(*nnet), mask_type=_mask_type)
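# Illustrative sketch (not part of the source above): the coupling blocks built by
# `_resblock` follow the standard affine-coupling pattern, in which `nnet` predicts a
# scale and shift for one half of the input conditioned on the other half. A minimal,
# self-contained version of that forward pass, assuming a generic PyTorch module `net`
# that maps C//2 channels to C channels (scale and shift stacked); the tanh bounding
# of the scale is an illustrative choice, not the repo's exact parameterization:
import torch
import torch.nn as nn

def affine_coupling_forward(x, net, swap=False):
    """Split channels, transform one half conditioned on the other.

    Returns the output and the log-determinant of the Jacobian.
    """
    x1, x2 = x.chunk(2, dim=1)
    if swap:  # alternating `swap` (as toggled via nonloc_scope.swap) mixes both halves
        x1, x2 = x2, x1
    s, t = net(x1).chunk(2, dim=1)
    logscale = torch.tanh(s)  # bounded scale for numerical stability
    y2 = x2 * torch.exp(logscale) + t
    logdet = logscale.flatten(1).sum(dim=1)
    return torch.cat([x1, y2], dim=1), logdet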
def create_cnf_model(args, data_shape, regularization_fns):
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    def build_cnf():
        diffeq = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        odefunc = layers.ODEfunc(
            diffeq=diffeq,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        cnf = layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )
        return cnf

    chain = [layers.LogitTransform(alpha=args.cnf_alpha)] if args.cnf_alpha > 0 else [layers.ZeroMeanTransform()]
    chain = chain + [build_cnf() for _ in range(args.num_blocks)]
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    model = layers.SequentialFlow(chain)
    return model
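# Hypothetical usage sketch for create_cnf_model. The field names below mirror the
# attributes read off `args` in the function; the concrete values are assumptions
# for illustration, not defaults taken from the source.
from types import SimpleNamespace

args = SimpleNamespace(
    dims="64,64,64", strides="1,1,1,1", conv=True,
    layer_type="concat", nonlinearity="softplus",
    divergence_fn="approximate", residual=False, rademacher=True,
    time_length=1.0, train_T=True, solver="dopri5",
    cnf_alpha=1e-6, num_blocks=2, batch_norm=True,
)
model = create_cnf_model(args, data_shape=(3, 32, 32), regularization_fns=None)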
def _resblock(initial_size, fc, idim=idim, first_resblock=False):
    if fc:
        return layers.iResBlock(
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                domains=domains,
                codomains=codomains,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
                learn_p=learn_p,
            ),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
    else:
        ks = list(map(int, kernels.split('-')))
        if learn_p:
            _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
            _codomains = _domains[1:] + [_domains[0]]
        else:
            _domains = domains
            _codomains = codomains

        nnet = []
        if not first_resblock and preact:
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            nnet.append(ACT_FNS[activation_fn](False))
        nnet.append(
            _lipschitz_layer(fc)(
                initial_size[0], idim, ks[0], 1, ks[0] // 2,
                coeff=coeff, n_iterations=n_lipschitz_iters,
                domain=_domains[0], codomain=_codomains[0],
                atol=sn_atol, rtol=sn_rtol,
            )
        )
        if batchnorm:
            nnet.append(layers.MovingBatchNorm2d(idim))
        nnet.append(ACT_FNS[activation_fn](True))
        for i, k in enumerate(ks[1:-1]):
            nnet.append(
                _lipschitz_layer(fc)(
                    idim, idim, k, 1, k // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[i + 1], codomain=_codomains[i + 1],
                    atol=sn_atol, rtol=sn_rtol,
                )
            )
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(idim))
            nnet.append(ACT_FNS[activation_fn](True))
        if dropout:
            nnet.append(nn.Dropout2d(dropout, inplace=True))
        nnet.append(
            _lipschitz_layer(fc)(
                idim, initial_size[0], ks[-1], 1, ks[-1] // 2,
                coeff=coeff, n_iterations=n_lipschitz_iters,
                domain=_domains[-1], codomain=_codomains[-1],
                atol=sn_atol, rtol=sn_rtol,
            )
        )
        if batchnorm:
            nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
        return layers.iResBlock(
            nn.Sequential(*nnet),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
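# Background sketch: an iResBlock computes y = x + g(x), with Lip(g) < 1 enforced by
# the spectral normalization inside `_lipschitz_layer`. Under that contraction
# condition the inverse exists and can be recovered by Banach fixed-point iteration.
# A minimal stand-alone illustration; the iteration count and tolerance below are
# assumptions, not values from the source:
import torch

def invert_residual_block(y, g, n_iters=100, atol=1e-6):
    """Solve x = y - g(x) by fixed-point iteration; valid when Lip(g) < 1."""
    x = y.clone()
    for _ in range(n_iters):
        x_next = y - g(x)
        if (x_next - x).abs().max() < atol:
            return x_next
        x = x_next
    return x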
def build_nnet():
    ks = list(map(int, kernels.split('-')))
    if learn_p:
        _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
        _codomains = _domains[1:] + [_domains[0]]
    else:
        _domains = domains
        _codomains = codomains

    nnet = []
    if not first_resblock and preact:
        if batchnorm:
            nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
        nnet.append(ACT_FNS[activation_fn](False))
    nnet.append(
        _lipschitz_layer(fc)(
            initial_size[0], idim, ks[0], 1, ks[0] // 2,
            coeff=coeff, n_iterations=n_lipschitz_iters,
            domain=_domains[0], codomain=_codomains[0],
            atol=sn_atol, rtol=sn_rtol,
        )
    )
    if batchnorm:
        nnet.append(layers.MovingBatchNorm2d(idim))
    nnet.append(ACT_FNS[activation_fn](True))
    for i, k in enumerate(ks[1:-1]):
        nnet.append(
            _lipschitz_layer(fc)(
                idim, idim, k, 1, k // 2,
                coeff=coeff, n_iterations=n_lipschitz_iters,
                domain=_domains[i + 1], codomain=_codomains[i + 1],
                atol=sn_atol, rtol=sn_rtol,
            )
        )
        if batchnorm:
            nnet.append(layers.MovingBatchNorm2d(idim))
        nnet.append(ACT_FNS[activation_fn](True))
    if dropout:
        nnet.append(nn.Dropout2d(dropout, inplace=True))
    nnet.append(
        _lipschitz_layer(fc)(
            idim, initial_size[0], ks[-1], 1, ks[-1] // 2,
            coeff=coeff, n_iterations=n_lipschitz_iters,
            domain=_domains[-1], codomain=_codomains[-1],
            atol=sn_atol, rtol=sn_rtol,
        )
    )
    if batchnorm:
        nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
    return nn.Sequential(*nnet)
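# Background sketch: the coeff/n_iterations/atol/rtol arguments passed to
# `_lipschitz_layer` above parameterize spectral normalization, which bounds each
# layer's Lipschitz constant. A minimal power-iteration estimate of a 2-D weight
# matrix's spectral norm (illustrative only; the repo's own implementation also
# handles convolutions and the learned mixed-norm domains/codomains):
import torch

def spectral_norm_estimate(weight, n_iterations=10):
    """Estimate the largest singular value of a 2-D weight matrix by power iteration."""
    u = torch.randn(weight.shape[0])
    for _ in range(n_iterations):
        v = torch.nn.functional.normalize(weight.t() @ u, dim=0)
        u = torch.nn.functional.normalize(weight @ v, dim=0)
    return torch.dot(u, weight @ v)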
def create_model(args, data_shape, regularization_fns):
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    if args.multiscale:
        model = odenvp.ODENVP(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            nonlinearity=args.nonlinearity,
            alpha=args.alpha,
            cnf_kwargs={
                "T": args.time_length,
                "train_T": args.train_T,
                "regularization_fns": regularization_fns,
            },
        )
    elif args.parallel:
        model = multiscale_parallel.MultiscaleParallelCNF(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            alpha=args.alpha,
            time_length=args.time_length,
        )
    else:
        if args.autoencode:

            def build_cnf():
                autoencoder_diffeq = layers.AutoencoderDiffEqNet(
                    hidden_dims=hidden_dims,
                    input_shape=data_shape,
                    strides=strides,
                    conv=args.conv,
                    layer_type=args.layer_type,
                    nonlinearity=args.nonlinearity,
                )
                odefunc = layers.AutoencoderODEfunc(
                    autoencoder_diffeq=autoencoder_diffeq,
                    divergence_fn=args.divergence_fn,
                    residual=args.residual,
                    rademacher=args.rademacher,
                )
                cnf = layers.CNF(
                    odefunc=odefunc,
                    T=args.time_length,
                    regularization_fns=regularization_fns,
                    solver=args.solver,
                )
                return cnf
        else:

            def build_cnf():
                diffeq = layers.ODEnet(
                    hidden_dims=hidden_dims,
                    input_shape=data_shape,
                    strides=strides,
                    conv=args.conv,
                    layer_type=args.layer_type,
                    nonlinearity=args.nonlinearity,
                )
                odefunc = layers.ODEfunc(
                    diffeq=diffeq,
                    divergence_fn=args.divergence_fn,
                    residual=args.residual,
                    rademacher=args.rademacher,
                )
                cnf = layers.CNF(
                    odefunc=odefunc,
                    T=args.time_length,
                    train_T=args.train_T,
                    regularization_fns=regularization_fns,
                    solver=args.solver,
                )
                return cnf

        chain = [layers.LogitTransform(alpha=args.alpha)] if args.alpha > 0 else [layers.ZeroMeanTransform()]
        chain = chain + [build_cnf() for _ in range(args.num_blocks)]
        if args.batch_norm:
            chain.append(layers.MovingBatchNorm2d(data_shape[0]))
        model = layers.SequentialFlow(chain)
    return model
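# Hypothetical usage sketch: one `args` object drives all three branches above
# (multiscale ODENVP, parallel multiscale, or a plain SequentialFlow chain). The
# field values here are assumptions for illustration only.
from types import SimpleNamespace

args = SimpleNamespace(
    dims="64,64,64", strides="1,1,1,1", num_blocks=2,
    multiscale=True, parallel=False, autoencode=False,
    conv=True, layer_type="concat", nonlinearity="softplus",
    divergence_fn="approximate", residual=False, rademacher=True,
    alpha=0.05, time_length=1.0, train_T=True, solver="dopri5",
    batch_size=64, batch_norm=False,
)
model = create_model(args, data_shape=(3, 32, 32), regularization_fns=None)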
def build_model(args, state_dict):
    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    # neural net that parameterizes the velocity field
    if args.autoencode:

        def build_cnf():
            autoencoder_diffeq = layers.AutoencoderDiffEqNet(
                hidden_dims=hidden_dims,
                input_shape=data_shape,
                strides=strides,
                conv=args.conv,
                layer_type=args.layer_type,
                nonlinearity=args.nonlinearity,
            )
            odefunc = layers.AutoencoderODEfunc(
                autoencoder_diffeq=autoencoder_diffeq,
                divergence_fn=args.divergence_fn,
                residual=args.residual,
                rademacher=args.rademacher,
            )
            cnf = layers.CNF(
                odefunc=odefunc,
                T=args.time_length,
                solver=args.solver,
            )
            return cnf
    else:

        def build_cnf():
            diffeq = layers.ODEnet(
                hidden_dims=hidden_dims,
                input_shape=data_shape,
                strides=strides,
                conv=args.conv,
                layer_type=args.layer_type,
                nonlinearity=args.nonlinearity,
            )
            odefunc = layers.ODEfunc(
                diffeq=diffeq,
                divergence_fn=args.divergence_fn,
                residual=args.residual,
                rademacher=args.rademacher,
            )
            cnf = layers.CNF(
                odefunc=odefunc,
                T=args.time_length,
                solver=args.solver,
            )
            return cnf

    chain = [layers.LogitTransform(alpha=args.alpha), build_cnf()]
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    model = layers.SequentialFlow(chain)

    if args.spectral_norm:
        add_spectral_norm(model)

    model.load_state_dict(state_dict)

    return model, test_loader.dataset
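# Hypothetical usage sketch for build_model: restoring a trained flow from a saved
# checkpoint for evaluation. The checkpoint path and the 'state_dict' key are
# assumptions about how the training script saves its model, not facts from the
# source above; `args` is assumed to be the parsed evaluation arguments.
import torch

checkpt = torch.load("experiments/cnf/checkpt.pth", map_location="cpu")
model, test_set = build_model(args, checkpt["state_dict"])
model.eval()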
def _resblock(initial_size, fc, idim=idim, first_resblock=False,
              densenet=densenet, densenet_depth=densenet_depth,
              densenet_growth=densenet_growth, fc_densenet_growth=fc_densenet_growth,
              learnable_concat=learnable_concat, lip_coeff=lip_coeff):
    if fc:
        return layers.iResBlock(
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                domains=domains,
                codomains=codomains,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
                learn_p=learn_p,
                densenet=densenet,
                densenet_depth=densenet_depth,
                densenet_growth=densenet_growth,
                fc_densenet_growth=fc_densenet_growth,
                learnable_concat=learnable_concat,
                lip_coeff=lip_coeff,
            ),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
    else:
        if densenet:
            ks = list(map(int, kernels.split('-')))
            if learn_p:
                _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                _codomains = _domains[1:] + [_domains[0]]
            else:
                _domains = domains
                _codomains = codomains

            nnet = []
            total_in_channels = initial_size[0]
            for i in range(densenet_depth):
                part_net = []
                # CLipSwish concatenates [x, -x] and doubles the channel count,
                # so halve the growth here:
                if activation_fn == 'CLipSwish':
                    output_channels = densenet_growth // 2
                else:
                    output_channels = densenet_growth
                part_net.append(
                    _lipschitz_layer(fc)(
                        total_in_channels, output_channels, 3, 1, padding=1,
                        coeff=coeff, n_iterations=n_lipschitz_iters,
                        domain=_domains[0], codomain=_codomains[0],
                        atol=sn_atol, rtol=sn_rtol,
                    )
                )
                if batchnorm:
                    # Normalize the conv output, which has output_channels channels.
                    part_net.append(layers.MovingBatchNorm2d(output_channels))
                part_net.append(ACT_FNS[activation_fn](False))
                nnet.append(
                    layers.LipschitzDenseLayer(
                        nn.Sequential(*part_net),
                        learnable_concat,
                        lip_coeff,
                    )
                )
                total_in_channels += densenet_growth

            # Last layer: 1x1 convolution back to the input channel count.
            nnet.append(
                _lipschitz_layer(fc)(
                    total_in_channels, initial_size[0], 1, 1, padding=0,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[0], codomain=_codomains[0],
                    atol=sn_atol, rtol=sn_rtol,
                )
            )
            return layers.iResBlock(
                nn.Sequential(*nnet),
                n_power_series=n_power_series,
                n_dist=n_dist,
                n_samples=n_samples,
                n_exact_terms=n_exact_terms,
                neumann_grad=neumann_grad,
                grad_in_forward=grad_in_forward,
            )
        # RESNET
        else:
            ks = list(map(int, kernels.split('-')))
            if learn_p:
                _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                _codomains = _domains[1:] + [_domains[0]]
            else:
                _domains = domains
                _codomains = codomains

            nnet = []
            # Change sizes for CLipSwish: halve the hidden width (the activation
            # doubles it back) and double the input width after preactivation.
            if activation_fn == 'CLipSwish':
                idim_out = idim // 2
                if not first_resblock and preact:
                    in_size = initial_size[0] * 2
                else:
                    in_size = initial_size[0]
            else:
                idim_out = idim
                in_size = initial_size[0]

            if not first_resblock and preact:
                if batchnorm:
                    nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                nnet.append(ACT_FNS[activation_fn](False))
            nnet.append(
                _lipschitz_layer(fc)(
                    in_size, idim_out, ks[0], 1, ks[0] // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[0], codomain=_codomains[0],
                    atol=sn_atol, rtol=sn_rtol,
                )
            )
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(idim))
            nnet.append(ACT_FNS[activation_fn](True))
            for i, k in enumerate(ks[1:-1]):
                nnet.append(
                    _lipschitz_layer(fc)(
                        idim, idim_out, k, 1, k // 2,
                        coeff=coeff, n_iterations=n_lipschitz_iters,
                        domain=_domains[i + 1], codomain=_codomains[i + 1],
                        atol=sn_atol, rtol=sn_rtol,
                    )
                )
                if batchnorm:
                    nnet.append(layers.MovingBatchNorm2d(idim))
                nnet.append(ACT_FNS[activation_fn](True))
            if dropout:
                nnet.append(nn.Dropout2d(dropout, inplace=True))
            nnet.append(
                _lipschitz_layer(fc)(
                    idim, initial_size[0], ks[-1], 1, ks[-1] // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[-1], codomain=_codomains[-1],
                    atol=sn_atol, rtol=sn_rtol,
                )
            )
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            return layers.iResBlock(
                nn.Sequential(*nnet),
                n_power_series=n_power_series,
                n_dist=n_dist,
                n_samples=n_samples,
                n_exact_terms=n_exact_terms,
                neumann_grad=neumann_grad,
                grad_in_forward=grad_in_forward,
            )
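# Background sketch of the dense branch above: each LipschitzDenseLayer concatenates
# its input with the output of part_net, so the channel count grows by
# densenet_growth per step until the final 1x1 convolution projects back to
# initial_size[0]. A minimal concatenating layer with learnable, normalized
# concatenation weights; the eta1/eta2 names and the normalization below are
# assumptions sketching the idea, not the repo's exact parameterization:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseConcat(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.eta1 = nn.Parameter(torch.ones(1))
        self.eta2 = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # Constrain the concat weights to unit 2-norm so the layer stays a
        # contraction whenever net itself is 1-Lipschitz.
        w = F.normalize(torch.cat([self.eta1, self.eta2]), dim=0)
        return torch.cat([w[0] * x, w[1] * self.net(x)], dim=1)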