Ejemplo n.º 1
0
 def build_cnf():
     """Create the CNF layer described by ``args`` and the outer-scope settings.

     The free names used here (``layers``, ``args``, ``hidden_dims``,
     ``data_shape``, ``strides``, ``regularization_fns``) are resolved from an
     outer scope (module globals or an enclosing function) that is not shown.

     Returns:
         A ``layers.CNF`` wrapping a ``layers.ODEfunc`` whose dynamics network
         is a ``layers.ODEnet``.
     """
     # Dynamics network, configured from the data shape and CLI options.
     net = layers.ODEnet(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=args.conv,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     # Wrap the network together with the divergence/residual options.
     func = layers.ODEfunc(
         diffeq=net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     # The CNF layer itself, with time, solver and adjoint settings.
     return layers.CNF(
         odefunc=func,
         T=args.time_length,
         train_T=args.train_T,
         regularization_fns=regularization_fns,
         solver=args.solver,
         num_steps=args.num_steps,
         adjoint=args.adjoint,
     )
Ejemplo n.º 2
0
 def build_cnf():
     """Create a non-convolutional CNF layer for inputs of shape ``(dims,)``.

     ``layers``, ``args``, ``hidden_dims`` and ``dims`` are resolved from an
     outer scope that is not shown here.

     Returns:
         A ``layers.CNF`` configured with the solver tolerances and polynomial
         options taken from ``args``.
     """
     # Flat input: no strides and convolution disabled.
     net = layers.ODEnet(
         hidden_dims=hidden_dims,
         input_shape=(dims,),
         strides=None,
         conv=False,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     func = layers.ODEfunc(
         diffeq=net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     # Separate tolerances are forwarded for training (atol/rtol) and
     # test-time (test_atol/test_rtol) use.
     return layers.CNF(
         odefunc=func,
         T=args.time_length,
         train_T=args.train_T,
         solver=args.solver,
         atol=args.atol,
         rtol=args.rtol,
         test_atol=args.test_atol,
         test_rtol=args.test_rtol,
         poly_num_sample=args.poly_num_sample,
         poly_order=args.poly_order,
         adjoint=args.adjoint,
     )
Ejemplo n.º 3
0
def create_cnf(
    diffeq,
    regularization_fns=None,
    *,
    solver="dopri5",
    divergence_fn="approximate",
    residual=False,
    rademacher=False,
    time_length=1.0,
    train_T=True,
):
    """Build a ``layers.CNF`` around ``diffeq`` and apply ``set_cnf_options``.

    The ODE/solver settings used to be hard-coded locals ("inlined args default
    values"). They are now keyword-only parameters with exactly the same
    defaults, so existing calls behave unchanged while callers can override them.

    Args:
        diffeq: Dynamics network, passed to ``layers.ODEfunc`` as ``diffeq``.
        regularization_fns: Forwarded to ``layers.CNF``; ``None`` by default.
        solver: Solver name forwarded to ``layers.CNF`` as ``solver``.
        divergence_fn: Forwarded to ``layers.ODEfunc`` as ``divergence_fn``.
            The original code noted ``"brute_force"`` as a possible alternative.
        residual: Forwarded to ``layers.ODEfunc``.
        rademacher: Forwarded to ``layers.ODEfunc``.
        time_length: Forwarded to ``layers.CNF`` as ``T``.
        train_T: Forwarded to ``layers.CNF`` as ``train_T``.

    Returns:
        The ``layers.CNF`` instance, after ``set_cnf_options`` has been applied.
    """
    odefunc = layers.ODEfunc(
        diffeq=diffeq,
        divergence_fn=divergence_fn,
        residual=residual,
        rademacher=rademacher,
    )
    cnf = layers.CNF(
        odefunc=odefunc,
        T=time_length,
        train_T=train_T,
        regularization_fns=regularization_fns,
        solver=solver,
    )

    # Post-construction configuration (defined elsewhere in the project).
    set_cnf_options(cnf)

    return cnf
Ejemplo n.º 4
0
 def _make_odefunc(size):
     """Wrap a concat-layer ``ODEnet`` built with ``size`` in a ``layers.ODEfunc``.

     ``idims``, ``strides``, ``nonlinearity``, ``ODEnet`` and ``layers`` are
     resolved from an outer scope that is not shown here.
     """
     # The fourth positional argument is the literal True; its meaning depends on
     # this ODEnet's signature (not visible here) -- verify against its definition.
     return layers.ODEfunc(
         ODEnet(idims, size, strides, True,
                layer_type="concat", nonlinearity=nonlinearity)
     )
Ejemplo n.º 5
0
 def _make_odefunc(size):
     """Wrap an ``ODEnet`` built with ``size`` in a ``layers.ODEfunc``.

     ``idims``, ``strides``, ``layer_type``, ``nonlinearity``, ``zero_last`` and
     ``div_samples`` are resolved from an outer scope that is not shown here.
     """
     dynamics = ODEnet(
         idims,
         size,
         strides,
         True,  # positional flag; meaning depends on ODEnet's signature (not shown)
         layer_type=layer_type,
         nonlinearity=nonlinearity,
         zero_last_weight=zero_last,
     )
     return layers.ODEfunc(dynamics, div_samples=div_samples)
Ejemplo n.º 6
0
 def _make_odefunc():
     """Build a ``layers.ODEfunc`` over a ``ParallelSumModules`` of per-scale nets.

     One concat-layer ``ODEnet`` is created for each ``scale`` in
     ``range(scales)``, sized by ``get_size(scale)`` and given
     ``num_squeeze=scale``. ``idims``, ``strides``, ``scales`` and ``get_size``
     are resolved from an outer scope that is not shown here.
     """
     def net_for_scale(scale):
         # The ODEnet used for a single scale.
         return ODEnet(idims,
                       get_size(scale),
                       strides,
                       True,
                       layer_type="concat",
                       num_squeeze=scale)

     combined = ParallelSumModules([net_for_scale(s) for s in range(scales)])
     return layers.ODEfunc(combined)
Ejemplo n.º 7
0
 def build_cnf():
     """Create a ``layers.CNF`` using only the time length and solver from ``args``.

     ``layers``, ``args``, ``hidden_dims``, ``data_shape`` and ``strides`` are
     resolved from an outer scope that is not shown here.
     """
     net_kwargs = dict(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=args.conv,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     dynamics = layers.ODEnet(**net_kwargs)

     func = layers.ODEfunc(
         diffeq=dynamics,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     # Only T and solver are set here; other CNF options keep their defaults.
     return layers.CNF(odefunc=func, T=args.time_length, solver=args.solver)
Ejemplo n.º 8
0
def construct_amortized_odefunc(args, z_dim, amortization_type="bias"):
    """Build a ``layers.ODEfunc`` around an amortized ODEnet variant.

    Args:
        args: Configuration namespace. It is passed to ``get_hidden_dims``, and
            ``layer_type``, ``nonlinearity``, ``divergence_fn``, ``residual``
            and ``rademacher`` are read from it. ``rank`` is also read when
            ``amortization_type == "low_rank"``.
        z_dim: Forwarded to the dynamics network as ``input_dim``.
        amortization_type: Selects the dynamics network class. Must be one of
            ``"bias"`` (``AmortizedBiasODEnet``), ``"hyper"`` (``HyperODEnet``),
            ``"lyper"`` (``LyperODEnet``) or ``"low_rank"``
            (``AmortizedLowRankODEnet``).

    Returns:
        The ``layers.ODEfunc`` wrapping the selected network.

    Raises:
        ValueError: If ``amortization_type`` is not one of the supported values.
            Previously such a value fell through every branch and failed later
            with an ``UnboundLocalError`` on ``diffeq``.
    """
    supported = ("bias", "hyper", "lyper", "low_rank")
    # Fail fast with a clear message instead of leaving ``diffeq`` unbound.
    if amortization_type not in supported:
        raise ValueError(
            "Unknown amortization_type {!r}; expected one of {}".format(
                amortization_type, ", ".join(supported)
            )
        )

    hidden_dims = get_hidden_dims(args)

    # Keyword arguments shared by every dynamics-network variant.
    common = dict(
        hidden_dims=hidden_dims,
        input_dim=z_dim,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )
    if amortization_type == "bias":
        diffeq = AmortizedBiasODEnet(**common)
    elif amortization_type == "hyper":
        diffeq = HyperODEnet(**common)
    elif amortization_type == "lyper":
        diffeq = LyperODEnet(**common)
    else:  # "low_rank" -- the only remaining value after validation above
        diffeq = AmortizedLowRankODEnet(**common, rank=args.rank)

    odefunc = layers.ODEfunc(
        diffeq=diffeq,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return odefunc
Ejemplo n.º 9
0
 def build_cnf():
     """Create a non-convolutional CNF layer for inputs of shape ``(dims,)``.

     ``layers``, ``args``, ``hidden_dims``, ``dims`` and ``regularization_fns``
     are resolved from an outer scope that is not shown here. ``residual`` is
     not passed to ``layers.ODEfunc``, so that callee's own default applies.
     """
     net = layers.ODEnet(
         hidden_dims=hidden_dims,
         input_shape=(dims,),
         strides=None,
         conv=False,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     func = layers.ODEfunc(
         diffeq=net,
         divergence_fn=args.divergence_fn,
         rademacher=args.rademacher,
     )
     return layers.CNF(
         odefunc=func,
         T=args.time_length,
         train_T=args.train_T,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )