def construct_model():
    """Build a 2-D normalizing flow from the global ``args``.

    With ``args.nf`` set, stacks ``args.depth`` planar flows. Otherwise
    stacks coupling layers with alternating swap, each optionally preceded
    by a brute-force layer when ``args.glow`` is set.
    """
    if args.nf:
        planar = [layers.PlanarFlow(2) for _ in range(args.depth)]
        return layers.SequentialFlow(planar)

    transforms = []
    for step in range(args.depth):
        if args.glow:
            transforms.append(layers.BruteForceLayer(2))
        # Alternate which half of the input is transformed at each step.
        transforms.append(layers.CouplingLayer(2, swap=step % 2 == 0))
    return layers.SequentialFlow(transforms)
def build_model_tabular(args, dims, regularization_fns=None):
    """Build a tabular CNF: ``args.num_blocks`` CNF blocks in sequence.

    Args:
        args: namespace of hyper-parameters (``dims`` is a '-'-separated
            string of hidden widths, plus layer_type, nonlinearity, ...).
        dims: input dimensionality of the data.
        regularization_fns: optional regularizations forwarded to each CNF;
            defaults to none.

    Returns:
        A ``layers.SequentialFlow`` over the CNF blocks.
    """
    # FIX: the default used to be a mutable ``[]`` shared across all calls
    # (classic mutable-default pitfall). Use a None sentinel and normalize
    # to a fresh list so downstream behavior is unchanged.
    if regularization_fns is None:
        regularization_fns = []

    hidden_dims = tuple(map(int, args.dims.split("-")))

    def build_cnf():
        diffeq = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims, ),
            strides=None,
            conv=False,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        odefunc = layers.ODEfunc(
            diffeq=diffeq,
            divergence_fn=args.divergence_fn,
            rademacher=args.rademacher,
        )
        cnf = layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )
        return cnf

    chain = [build_cnf() for _ in range(args.num_blocks)]
    model = layers.SequentialFlow(chain)
    return model
def create_cnf_model(args, data_shape, regularization_fns):
    """Assemble an image-space CNF chain from ``args``.

    Prepends a logit transform (or a zero-mean transform when
    ``args.cnf_alpha <= 0``), stacks ``args.num_blocks`` CNF blocks, and
    optionally appends a moving batch norm over the channel dimension.
    """
    hidden_dims = tuple(int(h) for h in args.dims.split(","))
    strides = tuple(int(s) for s in args.strides.split(","))

    def make_cnf_block():
        # Dynamics network -> divergence-augmented ODE function -> CNF.
        net = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        func = layers.ODEfunc(
            diffeq=net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=func,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )

    if args.cnf_alpha > 0:
        chain = [layers.LogitTransform(alpha=args.cnf_alpha)]
    else:
        chain = [layers.ZeroMeanTransform()]
    chain.extend(make_cnf_block() for _ in range(args.num_blocks))
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    return layers.SequentialFlow(chain)
def build_model_tabular(
        dims=2,
        condition_dim=2,
        layer_type='concatsquash',
        nonlinearity='relu',
        residual=False,
        rademacher=False,
        train_T=True,
        solver='dopri5',
        time_length=0.1,
        divergence_fn='brute_force',  # ["brute_force", "approximate"]
        hidden_dims=(32, 32),
        num_blocks=1,
        batch_norm=False,
        bn_lag=0,
        regularization_fns=None):
    """Build a conditional tabular CNF.

    Stacks ``num_blocks`` CNF blocks; when ``batch_norm`` is set, moving
    batch-norm layers are interleaved (one leading, one after each CNF
    block). Solver options are applied via ``set_cnf_options``.
    """

    def _single_cnf():
        net = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims, ),
            condition_dim=condition_dim,
            strides=None,
            conv=False,
            layer_type=layer_type,
            nonlinearity=nonlinearity,
        )
        func = layers.ODEfunc(
            diffeq=net,
            divergence_fn=divergence_fn,
            residual=residual,
            rademacher=rademacher,
        )
        return layers.CNF(
            odefunc=func,
            T=time_length,
            train_T=train_T,
            regularization_fns=regularization_fns,
            solver=solver,
        )

    chain = [_single_cnf() for _ in range(num_blocks)]
    if batch_norm:
        # Construct the per-block norms first, then the leading one, to keep
        # layer construction order identical to the original implementation.
        trailing_bns = [
            layers.MovingBatchNorm1d(dims, bn_lag=bn_lag)
            for _ in range(num_blocks)
        ]
        interleaved = [layers.MovingBatchNorm1d(dims, bn_lag=bn_lag)]
        for cnf, bn in zip(chain, trailing_bns):
            interleaved.append(cnf)
            interleaved.append(bn)
        chain = interleaved
    model = layers.SequentialFlow(chain)
    set_cnf_options(model, solver, rademacher, residual)
    return model
def build_model_tabular(args, dims, regularization_fns=None):
    """Build a tabular CNF with explicit train/test solver tolerances.

    NOTE(review): ``regularization_fns`` is accepted but never forwarded to
    ``layers.CNF`` in this variant — confirm whether that is intentional
    (this CNF takes atol/rtol/poly/adjoint options the other variants lack).
    """
    hidden_dims = tuple(int(h) for h in args.dims.split("-"))

    def _cnf_block():
        net = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims,),
            strides=None,
            conv=False,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        func = layers.ODEfunc(
            diffeq=net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=func,
            T=args.time_length,
            train_T=args.train_T,
            solver=args.solver,
            atol=args.atol,
            rtol=args.rtol,
            test_atol=args.test_atol,
            test_rtol=args.test_rtol,
            poly_num_sample=args.poly_num_sample,
            poly_order=args.poly_order,
            adjoint=args.adjoint,
        )

    chain = [_cnf_block() for _ in range(args.num_blocks)]
    if args.batch_norm:
        # One batch norm after each CNF block, plus a leading one; build the
        # per-block norms first to preserve the original construction order.
        per_block_bns = [
            layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)
            for _ in range(args.num_blocks)
        ]
        new_chain = [layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)]
        for cnf, bn in zip(chain, per_block_bns):
            new_chain.append(cnf)
            new_chain.append(bn)
        chain = new_chain
    model = layers.SequentialFlow(chain)
    set_cnf_options(args, model)
    return model
def build_model():
    """Construct the toy-data implicit flow: ``args.nblocks`` imBlocks,
    moved to the global ``device``."""
    act = ACTIVATION_FNS[args.act]
    # Hidden widths come from a '-'-separated string; in/out match data_dim.
    layer_dims = [data_dim, *map(int, args.dims.split('-')), data_dim]
    flow_blocks = [
        layers.imBlock(
            build_nnet(layer_dims, act),
            # ACTIVATION_FNS['zero'](),
            build_nnet(layer_dims, act),
            n_dist=args.n_dist,
            n_power_series=args.n_power_series,
            exact_trace=False,
            brute_force=args.brute_force,
            n_samples=args.n_samples,
            n_exact_terms=args.n_exact_terms,
            neumann_grad=False,
            grad_in_forward=False,  # toy data needn't save memory
            eps_forward=args.epsf,
        )
        for _ in range(args.nblocks)
    ]
    return layers.SequentialFlow(flow_blocks).to(device)
for _ in range(args.nblocks): blocks.append( layers.iResBlock( build_nnet(dims, activation_fn), n_dist=args.n_dist, n_power_series=args.n_power_series, exact_trace=args.exact_trace, brute_force=args.brute_force, n_samples=args.n_samples, neumann_grad=False, grad_in_forward=False, ) ) if args.actnorm: blocks.append(layers.ActNorm1d(2)) if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2)) model = layers.SequentialFlow(blocks).to(device) elif args.arch == 'realnvp': blocks = [] for _ in range(args.nblocks): blocks.append(layers.CouplingLayer(2, swap=False)) blocks.append(layers.CouplingLayer(2, swap=True)) if args.actnorm: blocks.append(layers.ActNorm1d(2)) if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2)) model = layers.SequentialFlow(blocks).to(device) logger.info(model) logger.info("Number of trainable parameters: {}".format(count_parameters(model))) #genGen = Generator(16) #genGen = Generator(128)
def create_model(args, data_shape, regularization_fns):
    """Build an image-space CNF model according to ``args``.

    Selects one of three architectures:
      * ``args.multiscale`` -> ``odenvp.ODENVP`` (multiscale CNF),
      * ``args.parallel``   -> ``multiscale_parallel.MultiscaleParallelCNF``,
      * otherwise           -> a ``SequentialFlow`` of CNF blocks preceded by
        a logit (or zero-mean) input transform.

    Args:
        args: parsed command-line namespace (dims, strides, num_blocks, ...).
        data_shape: per-sample input shape — presumably (C, H, W); the
            batch-norm branch reads ``data_shape[0]`` as the channel count.
        regularization_fns: regularizations forwarded to the CNF blocks.

    Returns:
        The constructed flow model.
    """
    # "dims"/"strides" arrive as comma-separated strings, e.g. "64,64,64".
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))
    if args.multiscale:
        model = odenvp.ODENVP(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            nonlinearity=args.nonlinearity,
            alpha=args.alpha,
            cnf_kwargs={
                "T": args.time_length,
                "train_T": args.train_T,
                "regularization_fns": regularization_fns
            },
        )
    elif args.parallel:
        model = multiscale_parallel.MultiscaleParallelCNF(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            alpha=args.alpha,
            time_length=args.time_length,
        )
    else:
        if args.autoencode:

            def build_cnf():
                # Autoencoder-structured dynamics network.
                autoencoder_diffeq = layers.AutoencoderDiffEqNet(
                    hidden_dims=hidden_dims,
                    input_shape=data_shape,
                    strides=strides,
                    conv=args.conv,
                    layer_type=args.layer_type,
                    nonlinearity=args.nonlinearity,
                )
                odefunc = layers.AutoencoderODEfunc(
                    autoencoder_diffeq=autoencoder_diffeq,
                    divergence_fn=args.divergence_fn,
                    residual=args.residual,
                    rademacher=args.rademacher,
                )
                # NOTE(review): unlike the non-autoencoder branch below,
                # ``train_T`` is not passed here — confirm intentional.
                cnf = layers.CNF(
                    odefunc=odefunc,
                    T=args.time_length,
                    regularization_fns=regularization_fns,
                    solver=args.solver,
                )
                return cnf
        else:

            def build_cnf():
                # Free-form dynamics network.
                diffeq = layers.ODEnet(
                    hidden_dims=hidden_dims,
                    input_shape=data_shape,
                    strides=strides,
                    conv=args.conv,
                    layer_type=args.layer_type,
                    nonlinearity=args.nonlinearity,
                )
                odefunc = layers.ODEfunc(
                    diffeq=diffeq,
                    divergence_fn=args.divergence_fn,
                    residual=args.residual,
                    rademacher=args.rademacher,
                )
                cnf = layers.CNF(
                    odefunc=odefunc,
                    T=args.time_length,
                    train_T=args.train_T,
                    regularization_fns=regularization_fns,
                    solver=args.solver,
                )
                return cnf

        # Input transform: logit squashing when alpha > 0, else zero-mean.
        chain = [layers.LogitTransform(alpha=args.alpha)
                 ] if args.alpha > 0 else [layers.ZeroMeanTransform()]
        chain = chain + [build_cnf() for _ in range(args.num_blocks)]
        if args.batch_norm:
            chain.append(layers.MovingBatchNorm2d(data_shape[0]))
        model = layers.SequentialFlow(chain)
    return model
if args.actnorm: blocks.append(layers.ActNorm1d(2)) for _ in range(args.nblocks): blocks.append( layers.iResBlock( build_nnet(dims, activation_fn), n_dist=args.n_dist, n_power_series=args.n_power_series, exact_trace=args.exact_trace, brute_force=args.brute_force, n_samples=args.n_samples, neumann_grad=False, grad_in_forward=False, )) if args.actnorm: blocks.append(layers.ActNorm1d(2)) if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2)) model = layers.SequentialFlow(blocks).to(device) elif args.arch == 'implicit': dims = [2] + list(map(int, args.dims.split('-'))) + [2] blocks = [] if args.actnorm: blocks.append(layers.ActNorm1d(2)) for _ in range(args.nblocks): blocks.append( layers.imBlock( build_nnet(dims, activation_fn), build_nnet(dims, activation_fn), n_dist=args.n_dist, n_power_series=args.n_power_series, exact_trace=args.exact_trace, brute_force=args.brute_force, n_samples=args.n_samples, neumann_grad=False,
def build_model(args, state_dict):
    """Rebuild the CNF model described by ``args`` and load trained weights.

    Args:
        args: parsed command-line namespace (dataset, dims, strides, ...).
        state_dict: trained parameters loaded into the rebuilt model.

    Returns:
        ``(model, test_dataset)`` — the test dataset comes from the loader
        produced by ``get_dataset``.
    """
    # load dataset (train_loader is unused here; only the shape and the
    # test split are needed)
    train_loader, test_loader, data_shape = get_dataset(args)
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    # neural net that parameterizes the velocity field
    if args.autoencode:

        def build_cnf():
            autoencoder_diffeq = layers.AutoencoderDiffEqNet(
                hidden_dims=hidden_dims,
                input_shape=data_shape,
                strides=strides,
                conv=args.conv,
                layer_type=args.layer_type,
                nonlinearity=args.nonlinearity,
            )
            odefunc = layers.AutoencoderODEfunc(
                autoencoder_diffeq=autoencoder_diffeq,
                divergence_fn=args.divergence_fn,
                residual=args.residual,
                rademacher=args.rademacher,
            )
            cnf = layers.CNF(
                odefunc=odefunc,
                T=args.time_length,
                solver=args.solver,
            )
            return cnf
    else:

        def build_cnf():
            diffeq = layers.ODEnet(
                hidden_dims=hidden_dims,
                input_shape=data_shape,
                strides=strides,
                conv=args.conv,
                layer_type=args.layer_type,
                nonlinearity=args.nonlinearity,
            )
            odefunc = layers.ODEfunc(
                diffeq=diffeq,
                divergence_fn=args.divergence_fn,
                residual=args.residual,
                rademacher=args.rademacher,
            )
            cnf = layers.CNF(
                odefunc=odefunc,
                T=args.time_length,
                solver=args.solver,
            )
            return cnf

    # Single CNF block behind a logit input transform (no num_blocks loop
    # in this evaluation-time builder).
    chain = [layers.LogitTransform(alpha=args.alpha), build_cnf()]
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    model = layers.SequentialFlow(chain)
    if args.spectral_norm:
        # Spectral norm must be re-applied before loading, so parameter
        # names in state_dict line up with the wrapped modules.
        add_spectral_norm(model)
    model.load_state_dict(state_dict)
    return model, test_loader.dataset
def build_augmented_model_tabular(args, dims, regularization_fns=None):
    """
    The function used for creating conditional Continuous Normalizing Flow
    with augmented neural ODE.

    Parameters:
        args: arguments used to create conditional CNF. Check args parser
            for details.
        dims: dimension of the input. Currently only allow 1-d input.
        regularization_fns: regularizations applied to the ODE function.

    Returns:
        a ctfp model based on augmented neural ode
    """
    hidden_dims = tuple(map(int, args.dims.split(",")))
    # Parse the augmented-net hidden sizes once; None means "use the
    # default architecture inside AugODEnet".
    if args.aug_hidden_dims is not None:
        aug_hidden_dims = tuple(map(int, args.aug_hidden_dims.split(",")))
    else:
        aug_hidden_dims = None

    def build_cnf():
        diffeq = layers.AugODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims, ),
            effective_shape=args.effective_shape,
            strides=None,
            conv=False,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
            aug_dim=args.aug_dim,
            aug_mapping=args.aug_mapping,
            # FIX: pass the parsed tuple rather than the raw comma-separated
            # string from args (the local ``aug_hidden_dims`` was previously
            # computed but never used).
            aug_hidden_dims=aug_hidden_dims,
        )
        odefunc = layers.AugODEfunc(
            diffeq=diffeq,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
            effective_shape=args.effective_shape,
        )
        cnf = layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
            rtol=args.rtol,
            atol=args.atol,
        )
        return cnf

    chain = [build_cnf() for _ in range(args.num_blocks)]
    if args.batch_norm:
        # Interleave moving batch norms: one leading, one after each block.
        bn_layers = [
            layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag,
                                     effective_shape=args.effective_shape)
            for _ in range(args.num_blocks)
        ]
        bn_chain = [
            layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag,
                                     effective_shape=args.effective_shape)
        ]
        for a, b in zip(chain, bn_layers):
            bn_chain.append(a)
            bn_chain.append(b)
        chain = bn_chain
    model = layers.SequentialFlow(chain)
    set_cnf_options(args, model)
    return model