Ejemplo n.º 1
0
    def __init__(
        self,
        initial_size,
        idims=(32,),
        nonlinearity="softplus",
        squeeze=True,
        init_layer=None,
        n_blocks=1,
        cnf_kwargs=None,
    ):
        """Build a sequential stack of CNF blocks, optionally around a squeeze.

        Args:
            initial_size: Size passed to ``ODEnet`` for the first stage. When
                ``squeeze`` is true it is unpacked as ``(c, h, w)``.
            idims: Forwarded as the first argument of every ``ODEnet``.
            nonlinearity: Nonlinearity name forwarded to every ``ODEnet``.
            squeeze: If true, the chain is ``n_blocks`` CNFs at
                ``initial_size``, then ``layers.SqueezeLayer(2)``, then
                ``n_blocks`` CNFs at ``(4 * c, h // 2, w // 2)``. Otherwise it
                is just ``n_blocks`` CNFs at ``initial_size``.
            init_layer: Optional layer placed at the front of the chain.
            n_blocks: Number of CNF blocks per stage.
            cnf_kwargs: Extra keyword arguments forwarded to every
                ``layers.CNF``. ``None`` (the default) means no extras.
        """
        # ``None`` sentinel instead of a shared ``{}`` default argument; callers
        # that pass their own dict see no difference.
        if cnf_kwargs is None:
            cnf_kwargs = {}

        # A leading stride of 1 plus one stride of 1 per entry of ``idims``.
        strides = tuple([1] + [1 for _ in idims])
        chain = []
        if init_layer is not None:
            chain.append(init_layer)

        def _make_odefunc(size):
            # Concat-type ODEnet for ``size``, wrapped in a layers.ODEfunc.
            net = ODEnet(idims, size, strides, True, layer_type="concat", nonlinearity=nonlinearity)
            return layers.ODEfunc(net)

        def _make_stage(size):
            # ``n_blocks`` independent CNFs, each with its own ODE function.
            return [layers.CNF(_make_odefunc(size), **cnf_kwargs) for _ in range(n_blocks)]

        if squeeze:
            c, h, w = initial_size
            # Size the post-squeeze CNFs are built for: channels x4,
            # height/width halved (floor division, so odd h/w are truncated).
            after_squeeze_size = c * 4, h // 2, w // 2
            pre = _make_stage(initial_size)
            post = _make_stage(after_squeeze_size)
            chain += pre + [layers.SqueezeLayer(2)] + post
        else:
            chain += _make_stage(initial_size)

        super(StackedCNFLayers, self).__init__(chain)
Ejemplo n.º 2
0
    def __init__(
            self,
            initial_size,
            idims=(32, ),
            scales=4,
            init_layer=None,
            n_blocks=1,
            time_length=1.,
    ):
        """Build ``n_blocks`` CNFs whose dynamics sum ODE nets across scales.

        Every CNF gets a fresh ``layers.ODEfunc`` wrapping a
        ``ParallelSumModules`` of ``scales`` concat-type ``ODEnet`` instances.
        The net for level ``s`` is built for size
        ``(initial_size[0] * 4**s, initial_size[1] // 2**s,
        initial_size[2] // 2**s)`` with ``num_squeeze=s``.

        Args:
            initial_size: Indexable base size; entries 0, 1, 2 are read.
            idims: Forwarded as the first argument of every ``ODEnet``.
            scales: Number of scale levels (``ODEnet`` instances) per CNF.
            init_layer: Optional layer placed at the front of the chain.
            n_blocks: Number of CNF blocks in the chain.
            time_length: Forwarded to each ``layers.CNF`` as ``T``.
        """
        stride_tuple = (1,) + tuple(1 for _ in idims)
        layer_chain = [] if init_layer is None else [init_layer]

        def _scaled_size(level):
            # Channels grow by 4**level while each spatial dim shrinks by 2**level.
            shrink = 2 ** level
            return (initial_size[0] * (4 ** level),
                    initial_size[1] // shrink,
                    initial_size[2] // shrink)

        def _build_odefunc():
            scale_nets = []
            for level in range(scales):
                scale_nets.append(
                    ODEnet(idims,
                           _scaled_size(level),
                           stride_tuple,
                           True,
                           layer_type="concat",
                           num_squeeze=level))
            return layers.ODEfunc(ParallelSumModules(scale_nets))

        for _ in range(n_blocks):
            layer_chain.append(layers.CNF(_build_odefunc(), T=time_length))

        super(ParallelCNFLayers, self).__init__(layer_chain)
Ejemplo n.º 3
0
 def build_cnf():
     """Return a ``layers.CNF`` over flat ``(dims,)`` inputs configured from ``args``.

     The vector field is a non-convolutional ``layers.ODEnet`` (no strides),
     wrapped in ``layers.ODEfunc``. The CNF receives the integration time,
     solver, train/test tolerances, polynomial options and adjoint flag
     read from the module-level ``args``.
     """
     vector_field = layers.ODEnet(
         hidden_dims=hidden_dims,
         input_shape=(dims,),
         strides=None,
         conv=False,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     dynamics = layers.ODEfunc(
         diffeq=vector_field,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     return layers.CNF(
         odefunc=dynamics,
         T=args.time_length,
         train_T=args.train_T,
         solver=args.solver,
         atol=args.atol,
         rtol=args.rtol,
         test_atol=args.test_atol,
         test_rtol=args.test_rtol,
         poly_num_sample=args.poly_num_sample,
         poly_order=args.poly_order,
         adjoint=args.adjoint,
     )
Ejemplo n.º 4
0
 def build_cnf():
     """Return a ``layers.CNF`` over ``data_shape`` inputs configured from ``args``.

     ``strides`` and ``args.conv`` are forwarded to the ``layers.ODEnet``.
     The CNF also receives ``regularization_fns``, a fixed ``num_steps`` and
     the adjoint flag.
     """
     net = layers.ODEnet(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=args.conv,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     func = layers.ODEfunc(
         diffeq=net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     flow = layers.CNF(
         odefunc=func,
         T=args.time_length,
         train_T=args.train_T,
         regularization_fns=regularization_fns,
         solver=args.solver,
         num_steps=args.num_steps,
         adjoint=args.adjoint,
     )
     return flow
Ejemplo n.º 5
0
def create_cnf(diffeq, regularization_fns=None):
    """Wrap ``diffeq`` in an ODEfunc and CNF using hard-coded settings.

    The settings are inlined stand-ins for command-line defaults:
    divergence ``"approximate"``, ``residual=False``, ``rademacher=False``,
    ``T=1.0``, ``train_T=True`` and solver ``"dopri5"``. After construction,
    the CNF is passed to ``set_cnf_options`` before being returned.

    Args:
        diffeq: Network used as the ODE function's ``diffeq``.
        regularization_fns: Forwarded unchanged to ``layers.CNF``.
    """
    # TODO(review): "brute_force" was noted as a possible divergence method.
    dynamics = layers.ODEfunc(
        diffeq=diffeq,
        divergence_fn="approximate",
        residual=False,
        rademacher=False,
    )
    # TODO(review): train_T=True carried a bare TODO in the original; revisit.
    flow = layers.CNF(
        odefunc=dynamics,
        T=1.0,
        train_T=True,
        regularization_fns=regularization_fns,
        solver="dopri5",
    )

    set_cnf_options(flow)

    return flow
 def build_cnf():
     """Return a ``layers.CNF`` over augmented flat ``(dims,)`` inputs.

     Uses ``layers.AugODEnet`` / ``layers.AugODEfunc``. The augmentation
     options, ``effective_shape``, solver and tolerances come from the
     module-level ``args``; ``regularization_fns`` is forwarded to the CNF.
     """
     augmented_net = layers.AugODEnet(
         hidden_dims=hidden_dims,
         input_shape=(dims,),
         effective_shape=args.effective_shape,
         strides=None,
         conv=False,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
         aug_dim=args.aug_dim,
         aug_mapping=args.aug_mapping,
         aug_hidden_dims=args.aug_hidden_dims,
     )
     augmented_func = layers.AugODEfunc(
         diffeq=augmented_net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
         effective_shape=args.effective_shape,
     )
     return layers.CNF(
         odefunc=augmented_func,
         T=args.time_length,
         train_T=args.train_T,
         regularization_fns=regularization_fns,
         solver=args.solver,
         rtol=args.rtol,
         atol=args.atol,
     )
Ejemplo n.º 7
0
 def build_cnf():
     """Return a minimal ``layers.CNF`` over ``data_shape`` inputs.

     Only the integration time and solver are passed to the CNF; all other
     CNF options are left at their defaults.
     """
     ode_net = layers.ODEnet(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=args.conv,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     ode_func = layers.ODEfunc(
         diffeq=ode_net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     return layers.CNF(odefunc=ode_func, T=args.time_length, solver=args.solver)
Ejemplo n.º 8
0
 def build_cnf():
     """Return a ``layers.CNF`` over flat ``(dims,)`` inputs configured from ``args``.

     The ``layers.ODEfunc`` is built without an explicit ``residual``
     argument. The CNF receives ``train_T``, ``regularization_fns`` and the
     solver.
     """
     field = layers.ODEnet(
         hidden_dims=hidden_dims,
         input_shape=(dims,),
         strides=None,
         conv=False,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     func = layers.ODEfunc(
         diffeq=field,
         divergence_fn=args.divergence_fn,
         rademacher=args.rademacher,
     )
     flow = layers.CNF(
         odefunc=func,
         T=args.time_length,
         train_T=args.train_T,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )
     return flow
Ejemplo n.º 9
0
 def build_cnf():
     """Return a ``layers.CNF`` whose dynamics use an autoencoder diff-eq net.

     Builds ``layers.AutoencoderDiffEqNet`` over ``data_shape`` and wraps it
     in ``layers.AutoencoderODEfunc``. The CNF receives the integration time,
     ``regularization_fns`` and the solver.
     """
     ae_net = layers.AutoencoderDiffEqNet(
         hidden_dims=hidden_dims,
         input_shape=data_shape,
         strides=strides,
         conv=args.conv,
         layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     ae_func = layers.AutoencoderODEfunc(
         autoencoder_diffeq=ae_net,
         divergence_fn=args.divergence_fn,
         residual=args.residual,
         rademacher=args.rademacher,
     )
     return layers.CNF(
         odefunc=ae_func,
         T=args.time_length,
         regularization_fns=regularization_fns,
         solver=args.solver,
     )