def __init__(
    self,
    initial_size,
    idims=(32, ),
    scales=4,
    init_layer=None,
    n_blocks=1,
    time_length=1.,
):
    """Build a chain of CNF blocks whose dynamics combine several squeeze scales.

    The chain is ``init_layer`` (when given) followed by ``n_blocks``
    ``layers.CNF`` blocks. Each block gets its own ``layers.ODEfunc``, which
    wraps a ``ParallelSumModules`` over ``scales`` ODEnets, one per squeeze
    level ``scale`` in ``range(scales)``.

    Args:
        initial_size: indexable input size. It is read as ``initial_size[0]``,
            ``[1]`` and ``[2]`` (assumed to be C, H, W; TODO confirm).
        idims: hidden dimensions forwarded to every ODEnet.
        scales: number of parallel ODEnets per ODEfunc.
        init_layer: optional layer placed at the front of the chain.
        n_blocks: number of CNF blocks appended to the chain.
        time_length: forwarded to each ``layers.CNF`` as ``T``.
    """
    # One stride of 1 for the input plus one per hidden dim.
    strides = (1, ) + tuple(1 for _ in idims)

    chain = [] if init_layer is None else [init_layer]

    def _scaled_size(scale):
        # Each squeeze level multiplies dim 0 by 4 and floor-halves dims 1 and 2.
        return (
            initial_size[0] * (4**scale),
            initial_size[1] // (2**scale),
            initial_size[2] // (2**scale),
        )

    def _build_odefunc():
        branches = []
        for scale in range(scales):
            # The positional True is ODEnet's 4th argument (presumably `conv`;
            # verify against the ODEnet signature).
            branches.append(
                ODEnet(idims, _scaled_size(scale), strides, True, layer_type="concat", num_squeeze=scale)
            )
        # NOTE(review): ParallelSumModules presumably sums the branch outputs; confirm.
        return layers.ODEfunc(ParallelSumModules(branches))

    for _ in range(n_blocks):
        chain.append(layers.CNF(_build_odefunc(), T=time_length))

    super(ParallelCNFLayers, self).__init__(chain)
def __init__(
    self,
    initial_size,
    idims=(32, ),
    nonlinearity="softplus",
    squeeze=True,
    init_layer=None,
    n_blocks=1,
    cnf_kwargs=None,
):
    """Build a stack of CNF blocks, optionally split by a factor-2 squeeze.

    The chain is ``init_layer`` (when given) followed by:

    * when ``squeeze`` is true, ``n_blocks`` CNFs at ``initial_size``, then
      ``layers.SqueezeLayer(2)``, then ``n_blocks`` CNFs at size
      ``(c * 4, h // 2, w // 2)``;
    * otherwise, ``n_blocks`` CNFs at ``initial_size``.

    Args:
        initial_size: input size. It is unpacked as ``c, h, w`` when
            ``squeeze`` is true.
        idims: hidden dimensions forwarded to every ODEnet.
        nonlinearity: nonlinearity name forwarded to every ODEnet.
        squeeze: whether to insert a squeeze layer between two CNF stages.
        init_layer: optional layer placed at the front of the chain.
        n_blocks: number of CNF blocks per stage.
        cnf_kwargs: extra keyword arguments passed to every ``layers.CNF``.
            ``None`` (the default) means no extra arguments.
    """
    # A None sentinel instead of a shared mutable `{}` default: the default
    # dict would otherwise be one object reused across every instance.
    if cnf_kwargs is None:
        cnf_kwargs = {}

    # One stride of 1 for the input plus one per hidden dim.
    strides = tuple([1] + [1 for _ in idims])

    chain = []
    if init_layer is not None:
        chain.append(init_layer)

    def _make_odefunc(size):
        net = ODEnet(idims, size, strides, True, layer_type="concat", nonlinearity=nonlinearity)
        f = layers.ODEfunc(net)
        return f

    if squeeze:
        c, h, w = initial_size
        # Size after SqueezeLayer(2): channels x4, spatial dims floor-halved.
        after_squeeze_size = c * 4, h // 2, w // 2
        # pre and post are built before the squeeze layer, matching the
        # original construction order.
        pre = [layers.CNF(_make_odefunc(initial_size), **cnf_kwargs) for _ in range(n_blocks)]
        post = [layers.CNF(_make_odefunc(after_squeeze_size), **cnf_kwargs) for _ in range(n_blocks)]
        chain += pre + [layers.SqueezeLayer(2)] + post
    else:
        chain += [layers.CNF(_make_odefunc(initial_size), **cnf_kwargs) for _ in range(n_blocks)]

    super(StackedCNFLayers, self).__init__(chain)
def build_cnf():
    """Construct one CNF layer from settings in the enclosing scope.

    Reads ``hidden_dims``, ``data_shape``, ``strides`` and ``args`` from the
    surrounding scope (not visible here). It wires ``layers.ODEnet`` into
    ``layers.ODEfunc`` and then into ``layers.CNF``.

    Returns:
        The newly built ``layers.CNF``.
    """
    odefunc = layers.ODEfunc(
        diffeq=layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        ),
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return layers.CNF(
        odefunc=odefunc,
        T=args.time_length,
        solver=args.solver,
    )
def build_cnf():
    """Construct one CNF layer backed by an autoencoder-style diffeq network.

    Reads ``hidden_dims``, ``data_shape``, ``strides``, ``regularization_fns``
    and ``args`` from the surrounding scope (not visible here). It wires
    ``layers.AutoencoderDiffEqNet`` into ``layers.AutoencoderODEfunc`` and then
    into ``layers.CNF``.

    Returns:
        The newly built ``layers.CNF``.
    """
    net_config = {
        "hidden_dims": hidden_dims,
        "input_shape": data_shape,
        "strides": strides,
        "conv": args.conv,
        "layer_type": args.layer_type,
        "nonlinearity": args.nonlinearity,
    }
    autoencoder_net = layers.AutoencoderDiffEqNet(**net_config)

    func_config = {
        "autoencoder_diffeq": autoencoder_net,
        "divergence_fn": args.divergence_fn,
        "residual": args.residual,
        "rademacher": args.rademacher,
    }
    autoencoder_func = layers.AutoencoderODEfunc(**func_config)

    cnf_config = {
        "odefunc": autoencoder_func,
        "T": args.time_length,
        "regularization_fns": regularization_fns,
        "solver": args.solver,
    }
    return layers.CNF(**cnf_config)
def build_cnf():
    """Construct one non-convolutional CNF layer over flat ``(dims, )`` inputs.

    Reads ``hidden_dims``, ``dims``, ``regularization_fns`` and ``args`` from
    the surrounding scope (not visible here). The ODEnet is built with
    ``strides=None`` and ``conv=False``. The CNF also receives ``train_T``
    from ``args``.

    Returns:
        The newly built ``layers.CNF``.
    """
    dynamics_net = layers.ODEnet(
        hidden_dims=hidden_dims,
        input_shape=(dims, ),
        strides=None,
        conv=False,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    dynamics = layers.ODEfunc(
        diffeq=dynamics_net,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return layers.CNF(
        odefunc=dynamics,
        T=args.time_length,
        train_T=args.train_T,
        regularization_fns=regularization_fns,
        solver=args.solver,
    )