def create_model(args, data_shape, regularization_fns):
    """Build an ``odenvp.ODENVP`` model from parsed options.

    The selected architecture settings are echoed to stdout before the model
    is constructed, and a final confirmation line is printed afterwards.

    Args:
        args: Options object. Reads ``dims`` and ``strides`` (comma-separated
            integer strings) plus ``batch_size``, ``num_blocks``,
            ``div_samples``, ``squeeze_first``, ``nonlinearity``,
            ``layer_type``, ``zero_last``, ``alpha``, ``time_length`` and
            ``train_T``.
        data_shape: Per-sample shape; it is unpacked after ``batch_size`` to
            form the first positional argument of ``ODENVP``.
        regularization_fns: Forwarded unchanged inside ``cnf_kwargs``.

    Returns:
        The constructed ``odenvp.ODENVP`` instance.
    """
    hidden_dims = tuple(int(token) for token in args.dims.split(","))
    strides = tuple(int(token) for token in args.strides.split(","))

    # Echo the configuration so it ends up in the run's stdout log.
    print("number of blocks: ", args.num_blocks)
    print("hidden_dims: ", hidden_dims)
    print("div_samples", args.div_samples)
    print("strides ", strides)
    print("squeeze_first ", args.squeeze_first)
    print("non linearity ", args.nonlinearity)
    print("layer_type ", args.layer_type)
    print("zero_last ", args.zero_last)

    model_cls = odenvp.ODENVP
    input_size = (args.batch_size, *data_shape)
    # Keyword order mirrors the direct call so options are read in the same
    # sequence.
    model_kwargs = {
        "n_blocks": args.num_blocks,
        "intermediate_dims": hidden_dims,
        "div_samples": args.div_samples,
        "strides": strides,
        "squeeze_first": args.squeeze_first,
        "nonlinearity": args.nonlinearity,
        "layer_type": args.layer_type,
        "zero_last": args.zero_last,
        "alpha": args.alpha,
        "cnf_kwargs": {
            "T": args.time_length,
            "train_T": args.train_T,
            "regularization_fns": regularization_fns,
        },
    }
    model = model_cls(input_size, **model_kwargs)

    print("model created ...")
    return model
def create_model(args, data_shape):
    """Build an ``odenvp.ODENVP`` model and apply post-construction options.

    Args:
        args: Options object. Reads ``dims`` (comma-separated integer
            string), ``num_blocks``, ``nonlinearity``, ``alpha``,
            ``time_length``, ``train_T`` and ``spectral_norm``; it is also
            passed as a whole to ``set_cnf_options``.
        data_shape: Per-sample shape; it is unpacked after the module-level
            ``BATCH_SIZE`` to form the first positional argument of
            ``ODENVP``.

    Returns:
        The constructed model, after ``add_spectral_norm`` (only when
        ``args.spectral_norm`` is truthy) and ``set_cnf_options`` have been
        called on it.
    """
    intermediate_dims = tuple(int(token) for token in args.dims.split(","))

    model_cls = odenvp.ODENVP
    input_size = (BATCH_SIZE, *data_shape)
    cnf = model_cls(
        input_size,
        n_blocks=args.num_blocks,
        intermediate_dims=intermediate_dims,
        nonlinearity=args.nonlinearity,
        alpha=args.alpha,
        cnf_kwargs={"T": args.time_length, "train_T": args.train_T},
    )

    # Spectral norm is applied before the CNF options are set.
    if args.spectral_norm:
        add_spectral_norm(cnf)
    set_cnf_options(args, cnf)

    return cnf
def create_model(args, data_shape, regularization_fns):
    """Build one of three flow models, selected by flags on ``args``.

    Selection order:

    1. ``args.multiscale`` truthy: an ``odenvp.ODENVP`` model.
    2. otherwise ``args.parallel`` truthy: a
       ``multiscale_parallel.MultiscaleParallelCNF`` model.
    3. otherwise: a ``layers.SequentialFlow`` made of an input transform,
       ``args.num_blocks`` CNF layers and, when ``args.batch_norm`` is
       truthy, a trailing ``layers.MovingBatchNorm2d``.

    Args:
        args: Options object. ``dims`` and ``strides`` are comma-separated
            integer strings and are always parsed; the remaining attributes
            are read only by the branch that uses them.
        data_shape: Per-sample shape. The first two branches unpack it after
            ``batch_size``; the sequential branch passes it as
            ``input_shape`` and uses ``data_shape[0]`` for the batch-norm
            layer.
        regularization_fns: Forwarded to the CNF layers. The
            ``MultiscaleParallelCNF`` branch does not use it.

    Returns:
        The constructed model.
    """
    hidden_dims = tuple(int(token) for token in args.dims.split(","))
    strides = tuple(int(token) for token in args.strides.split(","))

    if args.multiscale:
        return odenvp.ODENVP(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            nonlinearity=args.nonlinearity,
            alpha=args.alpha,
            cnf_kwargs={
                "T": args.time_length,
                "train_T": args.train_T,
                "regularization_fns": regularization_fns,
            },
        )

    if args.parallel:
        return multiscale_parallel.MultiscaleParallelCNF(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            alpha=args.alpha,
            time_length=args.time_length,
        )

    def _autoencoder_cnf():
        # One CNF layer whose dynamics come from an autoencoder network.
        net = layers.AutoencoderDiffEqNet(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        dynamics = layers.AutoencoderODEfunc(
            autoencoder_diffeq=net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        # NOTE: unlike the plain variant, no train_T is passed here.
        return layers.CNF(
            odefunc=dynamics,
            T=args.time_length,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )

    def _plain_cnf():
        # One CNF layer whose dynamics come from an ODEnet.
        net = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        dynamics = layers.ODEfunc(
            diffeq=net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=dynamics,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )

    build_cnf = _autoencoder_cnf if args.autoencode else _plain_cnf

    # The leading transform depends on whether alpha is positive.
    if args.alpha > 0:
        flow_layers = [layers.LogitTransform(alpha=args.alpha)]
    else:
        flow_layers = [layers.ZeroMeanTransform()]

    flow_layers.extend(build_cnf() for _ in range(args.num_blocks))

    if args.batch_norm:
        flow_layers.append(layers.MovingBatchNorm2d(data_shape[0]))

    return layers.SequentialFlow(flow_layers)