def build_model_tabular(
    dims=2,
    condition_dim=2,
    layer_type='concatsquash',
    nonlinearity='relu',
    residual=False,
    rademacher=False,
    train_T=True,
    solver='dopri5',
    time_length=0.1,
    divergence_fn='brute_force',  # one of: "brute_force", "approximate"
    hidden_dims=(32, 32),
    num_blocks=1,
    batch_norm=False,
    bn_lag=0,
    regularization_fns=None,
):
    """Build a conditional continuous normalizing flow for tabular data.

    Stacks ``num_blocks`` CNF blocks — each an ODEnet drift wrapped in an
    ODEfunc divergence estimator — into a SequentialFlow, optionally
    interleaving moving batch-norm layers between them, then applies the
    solver options via ``set_cnf_options``.
    """

    def _single_cnf():
        # Drift network of the ODE, conditioned on `condition_dim` extra inputs.
        drift = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims, ),
            condition_dim=condition_dim,
            strides=None,
            conv=False,
            layer_type=layer_type,
            nonlinearity=nonlinearity,
        )
        # Adds the divergence term needed for the change-of-variables formula.
        dynamics = layers.ODEfunc(
            diffeq=drift,
            divergence_fn=divergence_fn,
            residual=residual,
            rademacher=rademacher,
        )
        return layers.CNF(
            odefunc=dynamics,
            T=time_length,
            train_T=train_T,
            regularization_fns=regularization_fns,
            solver=solver,
        )

    chain = [_single_cnf() for _ in range(num_blocks)]

    if batch_norm:
        norms = [layers.MovingBatchNorm1d(dims, bn_lag=bn_lag) for _ in range(num_blocks)]
        # Interleave: a leading batch-norm, then alternating cnf/norm pairs.
        stacked = [layers.MovingBatchNorm1d(dims, bn_lag=bn_lag)]
        for cnf_block, norm in zip(chain, norms):
            stacked += [cnf_block, norm]
        chain = stacked

    model = layers.SequentialFlow(chain)
    set_cnf_options(model, solver, rademacher, residual)
    return model
def build_model_tabular(args, dims, regularization_fns=None):
    """Construct a CNF-based tabular flow from parsed command-line arguments.

    Parameters
    ----------
    args : argparse namespace supplying architecture, solver-tolerance and
        polynomial-estimator settings.
    dims : dimensionality of the data being modelled.
    regularization_fns : accepted for API compatibility.
        NOTE(review): currently unused inside this builder — confirm whether
        it should be forwarded to ``layers.CNF``.

    Returns
    -------
    A ``layers.SequentialFlow`` stacking ``args.num_blocks`` CNF blocks,
    optionally interleaved with moving batch-norm layers, with CNF solver
    options applied via ``set_cnf_options``.
    """
    # Hidden widths are encoded as a dash-separated string, e.g. "64-64-64".
    width_spec = tuple(int(w) for w in args.dims.split("-"))

    def _make_block():
        # Drift network of the ODE.
        net = layers.ODEnet(
            hidden_dims=width_spec,
            input_shape=(dims,),
            strides=None,
            conv=False,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        # Wraps the drift with divergence estimation for the log-density term.
        dynamics = layers.ODEfunc(
            diffeq=net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=dynamics,
            T=args.time_length,
            train_T=args.train_T,
            solver=args.solver,
            atol=args.atol,
            rtol=args.rtol,
            test_atol=args.test_atol,
            test_rtol=args.test_rtol,
            poly_num_sample=args.poly_num_sample,
            poly_order=args.poly_order,
            adjoint=args.adjoint,
        )

    chain = [_make_block() for _ in range(args.num_blocks)]

    if args.batch_norm:
        norms = [
            layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)
            for _ in range(args.num_blocks)
        ]
        # Leading batch-norm, then alternating cnf/norm pairs.
        interleaved = [layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)]
        for flow_block, norm in zip(chain, norms):
            interleaved.extend((flow_block, norm))
        chain = interleaved

    model = layers.SequentialFlow(chain)
    set_cnf_options(args, model)
    return model
# NOTE(review): this is a fragment of a larger `if args.arch == ...` chain —
# the opening branch header that this `elif` continues lies outside the
# visible chunk, so the fragment is syntactically incomplete here. The
# indentation below is reconstructed accordingly.

    # i-ResNet branch: stack `nblocks` invertible residual blocks, each
    # optionally followed by ActNorm and/or moving batch-norm.
    for _ in range(args.nblocks):
        blocks.append(
            layers.iResBlock(
                build_nnet(dims, activation_fn),
                n_dist=args.n_dist,
                n_power_series=args.n_power_series,
                exact_trace=args.exact_trace,
                brute_force=args.brute_force,
                n_samples=args.n_samples,
                # NOTE(review): Neumann-series gradient and in-forward
                # gradient computation are explicitly disabled — confirm
                # this is intentional for this experiment.
                neumann_grad=False,
                grad_in_forward=False,
            )
        )
        if args.actnorm:
            blocks.append(layers.ActNorm1d(2))  # data dimensionality hard-coded to 2
        if args.batchnorm:
            blocks.append(layers.MovingBatchNorm1d(2))
    model = layers.SequentialFlow(blocks).to(device)
elif args.arch == 'realnvp':
    # RealNVP branch: pairs of affine coupling layers with alternating
    # swaps so that every dimension gets transformed.
    blocks = []
    for _ in range(args.nblocks):
        blocks.append(layers.CouplingLayer(2, swap=False))
        blocks.append(layers.CouplingLayer(2, swap=True))
        if args.actnorm:
            blocks.append(layers.ActNorm1d(2))
        if args.batchnorm:
            blocks.append(layers.MovingBatchNorm1d(2))
    model = layers.SequentialFlow(blocks).to(device)

logger.info(model)
logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

#genGen = Generator(16)
#genGen = Generator(128)
# NOTE(review): incomplete fragment — the `elif` below continues an
# `if args.arch == ...` statement whose opening branch is not visible in
# this chunk; indentation is reconstructed to match.

    # i-ResNet architecture: `nblocks` invertible residual blocks, each
    # optionally followed by ActNorm / moving batch-norm layers.
    for _ in range(args.nblocks):
        blocks.append(
            layers.iResBlock(
                build_nnet(dims, activation_fn),
                n_dist=args.n_dist,
                n_power_series=args.n_power_series,
                exact_trace=args.exact_trace,
                brute_force=args.brute_force,
                n_samples=args.n_samples,
                # NOTE(review): both gradient-estimation shortcuts are
                # disabled here — verify this is deliberate.
                neumann_grad=False,
                grad_in_forward=False,
            ))
        if args.actnorm:
            blocks.append(layers.ActNorm1d(2))  # 2-d toy data hard-coded
        if args.batchnorm:
            blocks.append(layers.MovingBatchNorm1d(2))
    model = layers.SequentialFlow(blocks).to(device)
elif args.arch == 'realnvp':
    # RealNVP architecture: coupling-layer pairs with alternating swap
    # masks so all dimensions are updated.
    blocks = []
    for _ in range(args.nblocks):
        blocks.append(layers.CouplingLayer(2, swap=False))
        blocks.append(layers.CouplingLayer(2, swap=True))
        if args.actnorm:
            blocks.append(layers.ActNorm1d(2))
        if args.batchnorm:
            blocks.append(layers.MovingBatchNorm1d(2))
    model = layers.SequentialFlow(blocks).to(device)

logger.info(model)
logger.info("Number of trainable parameters: {}".format(
    count_parameters(model)))
def build_augmented_model_tabular(args, dims, regularization_fns=None):
    """
    The function used for creating conditional Continuous Normalizing Flow
    with augmented neural ODE.

    Parameters:
        args: arguments used to create conditional CNF. Check args parser for details.
        dims: dimension of the input. Currently only allow 1-d input.
        regularization_fns: regularizations applied to the ODE function

    Returns:
        a ctfp model based on augmented neural ode
    """
    hidden_dims = tuple(map(int, args.dims.split(",")))
    # Parse the augmentation-branch hidden widths once, up front; None means
    # the augmented network uses its default architecture.
    if args.aug_hidden_dims is not None:
        aug_hidden_dims = tuple(map(int, args.aug_hidden_dims.split(",")))
    else:
        aug_hidden_dims = None

    def build_cnf():
        diffeq = layers.AugODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims, ),
            effective_shape=args.effective_shape,
            strides=None,
            conv=False,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
            aug_dim=args.aug_dim,
            aug_mapping=args.aug_mapping,
            # BUG FIX: previously passed the raw string args.aug_hidden_dims,
            # leaving the parsed tuple above unused; pass the parsed tuple so
            # AugODEnet receives integer widths like `hidden_dims` does.
            aug_hidden_dims=aug_hidden_dims,
        )
        odefunc = layers.AugODEfunc(
            diffeq=diffeq,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
            effective_shape=args.effective_shape,
        )
        cnf = layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
            rtol=args.rtol,
            atol=args.atol,
        )
        return cnf

    chain = [build_cnf() for _ in range(args.num_blocks)]

    if args.batch_norm:
        bn_layers = [
            layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag,
                                     effective_shape=args.effective_shape)
            for _ in range(args.num_blocks)
        ]
        # Leading batch-norm, then alternating cnf/norm pairs.
        bn_chain = [
            layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag,
                                     effective_shape=args.effective_shape)
        ]
        for a, b in zip(chain, bn_layers):
            bn_chain.append(a)
            bn_chain.append(b)
        chain = bn_chain

    model = layers.SequentialFlow(chain)
    set_cnf_options(args, model)
    return model