def _make_odefunc(size):
    """Return an ``ODEfunc`` wrapping a single concat-layer ``ODEnet``.

    ``size`` is forwarded as the second positional argument of ``ODEnet``
    (presumably the input shape -- confirm against ``ODEnet``'s signature).
    ``idims``, ``strides`` and ``nonlinearity`` are free variables resolved
    from the enclosing scope; the fourth positional argument is ``True``.
    """
    return layers.ODEfunc(
        ODEnet(
            idims,
            size,
            strides,
            True,
            layer_type="concat",
            nonlinearity=nonlinearity,
        )
    )
def _make_odefunc():
    """Return an ``ODEfunc`` over one concat-layer ``ODEnet`` per scale.

    For each ``scale`` in ``range(scales)``, an ``ODEnet`` is built with
    ``get_size(scale)`` as its second positional argument and
    ``num_squeeze=scale``. The networks are combined with
    ``ParallelSumModules`` (presumably summing their outputs -- confirm).
    ``idims``, ``strides``, ``scales`` and ``get_size`` come from the
    enclosing scope.
    """
    per_scale_nets = []
    for scale in range(scales):
        per_scale_nets.append(
            ODEnet(
                idims,
                get_size(scale),
                strides,
                True,
                layer_type="concat",
                num_squeeze=scale,
            )
        )
    combined = ParallelSumModules(per_scale_nets)
    return layers.ODEfunc(combined)
def build_cnf():
    """Assemble a ``layers.CNF`` from an ``ODEnet`` -> ``ODEfunc`` chain.

    Network shape comes from the enclosing-scope ``hidden_dims``,
    ``data_shape`` and ``strides``. ``conv``, ``layer_type`` and
    ``nonlinearity`` are read from ``args``, as are the ``ODEfunc``
    options (``divergence_fn``, ``residual``, ``rademacher``) and the
    CNF's integration time (``time_length``) and ``solver``.
    """
    ode_net = layers.ODEnet(
        hidden_dims=hidden_dims,
        input_shape=data_shape,
        strides=strides,
        conv=args.conv,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    ode_func = layers.ODEfunc(
        diffeq=ode_net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return layers.CNF(odefunc=ode_func, T=args.time_length, solver=args.solver)
def construct_amortized_odefunc(args, z_dim, amortization_type="bias"):
    """Build a ``layers.ODEfunc`` around an amortized ODE network.

    Args:
        args: Config namespace. Reads ``layer_type``, ``nonlinearity``,
            ``divergence_fn``, ``residual`` and ``rademacher``, plus
            ``rank`` when ``amortization_type == "low_rank"``. It is also
            passed to ``get_hidden_dims``.
        z_dim: Forwarded as ``input_dim`` to the chosen network class.
        amortization_type: One of ``"bias"`` (``AmortizedBiasODEnet``),
            ``"hyper"`` (``HyperODEnet``), ``"lyper"`` (``LyperODEnet``) or
            ``"low_rank"`` (``AmortizedLowRankODEnet``).

    Returns:
        The ``layers.ODEfunc`` wrapping the constructed network.

    Raises:
        ValueError: If ``amortization_type`` is not one of the supported
            values. Previously this surfaced as an ``UnboundLocalError``
            on ``diffeq``.
    """
    valid_types = ("bias", "hyper", "lyper", "low_rank")
    # Validate before doing any work so a typo fails with a clear message
    # instead of falling through every branch and leaving `diffeq` unbound.
    if amortization_type not in valid_types:
        raise ValueError(
            f"Unknown amortization_type {amortization_type!r}; "
            f"expected one of {valid_types}"
        )

    hidden_dims = get_hidden_dims(args)
    # Keyword arguments shared by every amortized network constructor.
    common_kwargs = dict(
        hidden_dims=hidden_dims,
        input_dim=z_dim,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    if amortization_type == "bias":
        diffeq = AmortizedBiasODEnet(**common_kwargs)
    elif amortization_type == "hyper":
        diffeq = HyperODEnet(**common_kwargs)
    elif amortization_type == "lyper":
        diffeq = LyperODEnet(**common_kwargs)
    else:  # "low_rank" -- the only remaining valid type
        diffeq = AmortizedLowRankODEnet(rank=args.rank, **common_kwargs)

    odefunc = layers.ODEfunc(
        diffeq=diffeq,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return odefunc
def build_cnf():
    """Assemble a non-convolutional ``layers.CNF`` over ``dims``-sized inputs.

    The ``ODEnet`` uses the enclosing-scope ``hidden_dims`` with
    ``input_shape=(dims,)``, no strides and ``conv=False``. Layer type,
    nonlinearity and the ``ODEfunc`` options come from ``args``. The CNF
    also receives ``train_T`` from ``args`` and the enclosing-scope
    ``regularization_fns``.
    """
    ode_net = layers.ODEnet(
        hidden_dims=hidden_dims,
        input_shape=(dims,),
        strides=None,
        conv=False,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    ode_func = layers.ODEfunc(
        diffeq=ode_net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return layers.CNF(
        odefunc=ode_func,
        T=args.time_length,
        train_T=args.train_T,
        regularization_fns=regularization_fns,
        solver=args.solver,
    )