def __init__(
    self,
    initial_size,
    idims=(32,),
    nonlinearity="softplus",
    squeeze=True,
    init_layer=None,
    n_blocks=1,
    cnf_kwargs=None,
):
    """Build a sequential stack of CNF blocks, optionally split by a squeeze.

    Args:
        initial_size: Size handed to the ODE nets of the first stage. When
            ``squeeze`` is True it is unpacked as ``(c, h, w)``.
        idims: Hidden dimensions forwarded to every ``ODEnet``.
        nonlinearity: Activation name forwarded to every ``ODEnet``.
        squeeze: If True, place ``n_blocks`` CNFs before and ``n_blocks``
            CNFs after a ``layers.SqueezeLayer(2)``. The post-squeeze nets are
            built for size ``(c * 4, h // 2, w // 2)``. If False, build a
            single stage of ``n_blocks`` CNFs on ``initial_size``.
        init_layer: Optional layer prepended to the chain.
        n_blocks: Number of CNF blocks per stage.
        cnf_kwargs: Extra keyword arguments passed to every ``layers.CNF``.
            ``None`` (the default) means no extra arguments. It replaces the
            former shared mutable ``{}`` default.
    """
    if cnf_kwargs is None:
        cnf_kwargs = {}

    # All-ones strides, one entry more than there are hidden dims.
    strides = tuple([1] + [1 for _ in idims])

    chain = []
    if init_layer is not None:
        chain.append(init_layer)

    def _make_odefunc(size):
        # 4th positional argument True: presumably ODEnet's `conv` flag --
        # TODO confirm against the ODEnet signature.
        net = ODEnet(idims, size, strides, True, layer_type="concat", nonlinearity=nonlinearity)
        return layers.ODEfunc(net)

    if squeeze:
        c, h, w = initial_size
        after_squeeze_size = c * 4, h // 2, w // 2
        pre = [layers.CNF(_make_odefunc(initial_size), **cnf_kwargs) for _ in range(n_blocks)]
        post = [layers.CNF(_make_odefunc(after_squeeze_size), **cnf_kwargs) for _ in range(n_blocks)]
        chain += pre + [layers.SqueezeLayer(2)] + post
    else:
        chain += [layers.CNF(_make_odefunc(initial_size), **cnf_kwargs) for _ in range(n_blocks)]

    super(StackedCNFLayers, self).__init__(chain)
def __init__(
    self,
    initial_size,
    idims=(32,),
    scales=4,
    init_layer=None,
    n_blocks=1,
    time_length=1.0,
):
    """Build ``n_blocks`` CNFs whose dynamics sum several per-scale ODE nets.

    Args:
        initial_size: Indexable size; elements 0, 1 and 2 are used to derive
            each scale's size.
        idims: Hidden dimensions forwarded to every ``ODEnet``.
        scales: Number of parallel ``ODEnet`` branches. Branch ``s`` is
            built for size ``(size[0] * 4**s, size[1] // 2**s,
            size[2] // 2**s)`` with ``num_squeeze=s``.
        init_layer: Optional layer prepended to the chain.
        n_blocks: Number of CNF blocks to create.
        time_length: Integration time ``T`` passed to each ``layers.CNF``.
    """
    strides = tuple([1] + [1 for _ in idims])
    chain = [init_layer] if init_layer is not None else []

    def get_size(s):
        # Channel count grows by 4**s while both spatial dims shrink by 2**s.
        return (
            initial_size[0] * (4 ** s),
            initial_size[1] // (2 ** s),
            initial_size[2] // (2 ** s),
        )

    def _make_odefunc():
        branches = []
        for level in range(scales):
            branches.append(
                ODEnet(idims, get_size(level), strides, True, layer_type="concat", num_squeeze=level)
            )
        # The branch outputs are combined by ParallelSumModules.
        return layers.ODEfunc(ParallelSumModules(branches))

    for _ in range(n_blocks):
        chain.append(layers.CNF(_make_odefunc(), T=time_length))

    super(ParallelCNFLayers, self).__init__(chain)
def build_cnf():
    """Create one ``layers.CNF`` from the settings in the global ``args``.

    Relies on module-level names not defined in this function:
    ``hidden_dims``, ``dims`` and ``args``. The ODE net is non-convolutional
    (``conv=False``, ``strides=None``) over inputs of shape ``(dims,)``.
    The CNF receives separate ``atol``/``rtol`` and ``test_atol``/``test_rtol``
    tolerances plus polynomial-sampling options, all read from ``args``.

    Returns:
        The constructed ``layers.CNF``.
    """
    net = layers.ODEnet(
        hidden_dims=hidden_dims,
        input_shape=(dims,),
        strides=None,
        conv=False,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    dynamics = layers.ODEfunc(
        diffeq=net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return layers.CNF(
        odefunc=dynamics,
        T=args.time_length,
        train_T=args.train_T,
        solver=args.solver,
        atol=args.atol,
        rtol=args.rtol,
        test_atol=args.test_atol,
        test_rtol=args.test_rtol,
        poly_num_sample=args.poly_num_sample,
        poly_order=args.poly_order,
        adjoint=args.adjoint,
    )
def build_cnf():
    """Create one ``layers.CNF`` configured from the global ``args``.

    Relies on module-level names not defined in this function:
    ``hidden_dims``, ``data_shape``, ``strides``, ``regularization_fns`` and
    ``args``. ``args.conv`` controls whether the ODE net is convolutional.
    ``args.num_steps`` and ``args.adjoint`` are forwarded to the CNF.

    Returns:
        The constructed ``layers.CNF``.
    """
    net = layers.ODEnet(
        hidden_dims=hidden_dims,
        input_shape=data_shape,
        strides=strides,
        conv=args.conv,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    dynamics = layers.ODEfunc(
        diffeq=net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return layers.CNF(
        odefunc=dynamics,
        T=args.time_length,
        train_T=args.train_T,
        regularization_fns=regularization_fns,
        solver=args.solver,
        num_steps=args.num_steps,
        adjoint=args.adjoint,
    )
def create_cnf(
    diffeq,
    regularization_fns=None,
    *,
    solver="dopri5",
    divergence_fn="approximate",
    residual=False,
    rademacher=False,
    time_length=1.0,
    train_T=True,
):
    """Wrap ``diffeq`` in an ``ODEfunc`` and a ``CNF``, then apply CNF options.

    The keyword-only parameters replace constants that were previously
    hard-coded ("inlined args default values"). Their defaults are the same
    values, so existing calls behave as before.

    Args:
        diffeq: The differential-equation network given to ``layers.ODEfunc``.
        regularization_fns: Forwarded to ``layers.CNF``.
        solver: ODE solver name for ``layers.CNF``.
        divergence_fn: Divergence estimator name for ``layers.ODEfunc``.
            The original code noted ``"brute_force"`` as an alternative.
        residual: Forwarded to ``layers.ODEfunc``.
        rademacher: Forwarded to ``layers.ODEfunc``.
        time_length: Integration time ``T`` for ``layers.CNF``.
        train_T: Forwarded to ``layers.CNF``.

    Returns:
        The configured ``layers.CNF``.

    Note:
        ``set_cnf_options`` is defined elsewhere and runs after construction.
        Whether it overrides any of the values above is not visible here --
        verify before relying on these parameters.
    """
    odefunc = layers.ODEfunc(
        diffeq=diffeq,
        divergence_fn=divergence_fn,
        residual=residual,
        rademacher=rademacher,
    )
    cnf = layers.CNF(
        odefunc=odefunc,
        T=time_length,
        train_T=train_T,
        regularization_fns=regularization_fns,
        solver=solver,
    )
    set_cnf_options(cnf)
    return cnf
def build_cnf():
    """Create one augmented ``layers.CNF`` configured from the global ``args``.

    Uses ``layers.AugODEnet`` and ``layers.AugODEfunc``. Both receive
    ``args.effective_shape``, and the net additionally receives the
    ``aug_*`` options. Relies on module-level names not defined in this
    function: ``hidden_dims``, ``dims``, ``regularization_fns`` and ``args``.

    Returns:
        The constructed ``layers.CNF``.
    """
    aug_net = layers.AugODEnet(
        hidden_dims=hidden_dims,
        input_shape=(dims,),
        effective_shape=args.effective_shape,
        strides=None,
        conv=False,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
        aug_dim=args.aug_dim,
        aug_mapping=args.aug_mapping,
        aug_hidden_dims=args.aug_hidden_dims,
    )
    aug_dynamics = layers.AugODEfunc(
        diffeq=aug_net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
        effective_shape=args.effective_shape,
    )
    return layers.CNF(
        odefunc=aug_dynamics,
        T=args.time_length,
        train_T=args.train_T,
        regularization_fns=regularization_fns,
        solver=args.solver,
        rtol=args.rtol,
        atol=args.atol,
    )
def build_cnf():
    """Create one ``layers.CNF`` configured from the global ``args``.

    Only ``T`` and ``solver`` are set on the CNF. Relies on module-level
    names not defined in this function: ``hidden_dims``, ``data_shape``,
    ``strides`` and ``args``.

    Returns:
        The constructed ``layers.CNF``.
    """
    net = layers.ODEnet(
        hidden_dims=hidden_dims,
        input_shape=data_shape,
        strides=strides,
        conv=args.conv,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    dynamics = layers.ODEfunc(
        diffeq=net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return layers.CNF(odefunc=dynamics, T=args.time_length, solver=args.solver)
def build_cnf():
    """Create one ``layers.CNF`` from the global ``args`` for flat inputs.

    The ODE net is non-convolutional (``conv=False``, ``strides=None``) over
    inputs of shape ``(dims,)``. Relies on module-level names not defined in
    this function: ``hidden_dims``, ``dims``, ``regularization_fns`` and
    ``args``.

    Returns:
        The constructed ``layers.CNF``.
    """
    net = layers.ODEnet(
        hidden_dims=hidden_dims,
        input_shape=(dims,),
        strides=None,
        conv=False,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    # NOTE(review): `residual` is not forwarded, so ODEfunc's own default
    # applies -- confirm this is intended.
    dynamics = layers.ODEfunc(
        diffeq=net,
        divergence_fn=args.divergence_fn,
        rademacher=args.rademacher,
    )
    return layers.CNF(
        odefunc=dynamics,
        T=args.time_length,
        train_T=args.train_T,
        regularization_fns=regularization_fns,
        solver=args.solver,
    )
def build_cnf():
    """Create one ``layers.CNF`` whose dynamics use an autoencoder diffeq net.

    Builds a ``layers.AutoencoderDiffEqNet``, wraps it in
    ``layers.AutoencoderODEfunc`` and returns the resulting CNF. Relies on
    module-level names not defined in this function: ``hidden_dims``,
    ``data_shape``, ``strides``, ``regularization_fns`` and ``args``.

    Returns:
        The constructed ``layers.CNF``.
    """
    ae_net = layers.AutoencoderDiffEqNet(
        hidden_dims=hidden_dims,
        input_shape=data_shape,
        strides=strides,
        conv=args.conv,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    ae_dynamics = layers.AutoencoderODEfunc(
        autoencoder_diffeq=ae_net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return layers.CNF(
        odefunc=ae_dynamics,
        T=args.time_length,
        regularization_fns=regularization_fns,
        solver=args.solver,
    )