Example #1
0
 def _make_odefunc(size):
     """Return a ``layers.ODEfunc`` wrapping a concat-layer ``ODEnet``.

     ``idims``, ``strides`` and ``nonlinearity`` are free names resolved from
     the surrounding scope; only ``size`` varies between calls. The fourth
     positional argument to ``ODEnet`` is always ``True``.
     """
     # NOTE(review): positional order is presumably
     # (hidden_dims, input_shape, strides, conv) -- confirm against ODEnet.
     return layers.ODEfunc(
         ODEnet(
             idims,
             size,
             strides,
             True,
             layer_type="concat",
             nonlinearity=nonlinearity,
         )
     )
Example #2
0
 def _make_odefunc():
     """Return a ``layers.ODEfunc`` whose network sums one ODEnet per scale.

     One concat-layer ``ODEnet`` is built for each ``scale`` in
     ``range(scales)``. Its second argument is ``get_size(scale)`` and its
     ``num_squeeze`` is ``scale``. The branches are combined with
     ``ParallelSumModules``. No ``nonlinearity`` argument is passed here.
     ``idims``, ``strides``, ``scales`` and ``get_size`` are free names
     resolved from the surrounding scope.
     """
     def branch(level):
         # One per-scale network; ``level`` doubles as the squeeze count.
         return ODEnet(idims, get_size(level), strides, True,
                       layer_type="concat", num_squeeze=level)

     summed = ParallelSumModules([branch(level) for level in range(scales)])
     return layers.ODEfunc(summed)
Example #3
0
 def build_cnf():
     """Assemble and return a ``layers.CNF`` from ``args`` and outer settings.

     The vector field is a ``layers.ODEnet``. Its architecture comes from
     ``hidden_dims``, ``data_shape`` and ``strides`` in the surrounding scope,
     plus ``args.conv``, ``args.layer_type`` and ``args.nonlinearity``.
     It is wrapped in a ``layers.ODEfunc`` configured by
     ``args.divergence_fn``, ``args.residual`` and ``args.rademacher``.
     The CNF is then built with ``T=args.time_length`` and
     ``solver=args.solver``.
     """
     net_options = dict(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=args.conv,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     dynamics = layers.ODEfunc(
         diffeq=layers.ODEnet(**net_options),
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     return layers.CNF(odefunc=dynamics, T=args.time_length, solver=args.solver)
Example #4
0
def construct_amortized_odefunc(args, z_dim, amortization_type="bias"):
    """Build a ``layers.ODEfunc`` around an amortized ODE network.

    Args:
        args: Configuration namespace.
            - Passed to ``get_hidden_dims`` to obtain the hidden layer sizes.
            - ``layer_type`` and ``nonlinearity`` are always read.
            - ``divergence_fn``, ``residual`` and ``rademacher`` are always read.
            - ``rank`` is read only when ``amortization_type == "low_rank"``.
        z_dim: Value passed as ``input_dim`` to the network constructor.
        amortization_type: Which network class to build:
            - ``"bias"``: ``AmortizedBiasODEnet``
            - ``"hyper"``: ``HyperODEnet``
            - ``"lyper"``: ``LyperODEnet``
            - ``"low_rank"``: ``AmortizedLowRankODEnet``

    Returns:
        The ``layers.ODEfunc`` wrapping the selected network.

    Raises:
        ValueError: If ``amortization_type`` is not one of the values above.
    """
    supported = ("bias", "hyper", "lyper", "low_rank")
    # Validate up front so an unknown value fails with a clear message
    # instead of leaving ``diffeq`` unbound further down.
    if amortization_type not in supported:
        raise ValueError(
            "Unknown amortization_type {!r}; expected one of {}".format(
                amortization_type, supported))

    hidden_dims = get_hidden_dims(args)
    # Keyword arguments shared by every network constructor below.
    net_kwargs = dict(
        hidden_dims=hidden_dims,
        input_dim=z_dim,
        layer_type=args.layer_type,
        nonlinearity=args.nonlinearity,
    )

    if amortization_type == "bias":
        diffeq = AmortizedBiasODEnet(**net_kwargs)
    elif amortization_type == "hyper":
        diffeq = HyperODEnet(**net_kwargs)
    elif amortization_type == "lyper":
        diffeq = LyperODEnet(**net_kwargs)
    else:  # "low_rank" -- the only variant that also takes a rank.
        diffeq = AmortizedLowRankODEnet(rank=args.rank, **net_kwargs)

    odefunc = layers.ODEfunc(
        diffeq=diffeq,
        divergence_fn=args.divergence_fn,
        residual=args.residual,
        rademacher=args.rademacher,
    )
    return odefunc
Example #5
0
 def build_cnf():
     """Assemble and return a non-convolutional ``layers.CNF`` over flat inputs.

     The vector field is a ``layers.ODEnet`` with these settings:
       - ``input_shape=(dims,)``, ``strides=None`` and ``conv=False``;
       - ``hidden_dims`` taken from the surrounding scope;
       - layer type and nonlinearity taken from ``args``.
     It is wrapped in a ``layers.ODEfunc`` configured by
     ``args.divergence_fn``, ``args.residual`` and ``args.rademacher``.
     The CNF receives ``args.time_length`` as ``T``, ``args.train_T`` as
     ``train_T``, the outer ``regularization_fns``, and ``args.solver``.
     """
     dynamics = layers.ODEfunc(
         diffeq=layers.ODEnet(
             hidden_dims=hidden_dims,
             input_shape=(dims,),
             strides=None,
             conv=False,
             layer_type=args.layer_type,
             nonlinearity=args.nonlinearity,
         ),
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     return layers.CNF(
         odefunc=dynamics,
         T=args.time_length,
         train_T=args.train_T,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )