예제 #1
0
    def __init__(
            self,
            initial_size,
            idims=(32, ),
            scales=4,
            init_layer=None,
            n_blocks=1,
            time_length=1.,
    ):
        """Build a chain of CNF blocks whose dynamics sum several ODE nets.

        Every CNF block wraps a ``layers.ODEfunc`` around a
        ``ParallelSumModules`` of ``scales`` ``ODEnet`` branches. The branch
        for scale ``s`` is constructed with size
        ``(initial_size[0] * 4**s, initial_size[1] // 2**s,
        initial_size[2] // 2**s)`` and ``num_squeeze=s``.

        Args:
            initial_size: indexable with at least three entries; presumably
                (channels, height, width) -- TODO confirm against callers.
            idims: hidden dimensions passed to every ``ODEnet``.
            scales: number of parallel ``ODEnet`` branches per block.
            init_layer: optional layer placed first in the chain.
            n_blocks: number of CNF blocks appended after ``init_layer``.
            time_length: value passed as ``T`` to each ``layers.CNF``.
        """
        # All-ones strides: one entry for the input plus one per hidden dim.
        strides = (1,) + tuple(1 for _ in idims)

        chain = [] if init_layer is None else [init_layer]

        def squeezed_size(scale):
            # Channels grow by 4**scale; both spatial dims shrink by 2**scale.
            shrink = 2**scale
            return (initial_size[0] * (4**scale),
                    initial_size[1] // shrink,
                    initial_size[2] // shrink)

        def build_odefunc():
            # Construct the branches in ascending scale order, then sum them.
            branches = []
            for scale in range(scales):
                branches.append(
                    ODEnet(idims,
                           squeezed_size(scale),
                           strides,
                           True,
                           layer_type="concat",
                           num_squeeze=scale))
            return layers.ODEfunc(ParallelSumModules(branches))

        for _ in range(n_blocks):
            chain.append(layers.CNF(build_odefunc(), T=time_length))

        super(ParallelCNFLayers, self).__init__(chain)
예제 #2
0
    def __init__(
        self,
        initial_size,
        idims=(32, ),
        nonlinearity="softplus",
        squeeze=True,
        init_layer=None,
        n_blocks=1,
        cnf_kwargs=None,
    ):
        """Build a stack of CNF blocks, optionally split by a squeeze layer.

        With ``squeeze=True`` the chain is: ``n_blocks`` CNFs built for
        ``initial_size``, then ``layers.SqueezeLayer(2)``, then ``n_blocks``
        CNFs built for ``(c * 4, h // 2, w // 2)``. With ``squeeze=False`` it
        is just ``n_blocks`` CNFs built for ``initial_size``. An optional
        ``init_layer`` is placed first in either case.

        Args:
            initial_size: size passed to each ``ODEnet``; must unpack into
                exactly three values ``c, h, w`` when ``squeeze`` is true
                (presumably channels, height, width -- TODO confirm).
            idims: hidden dimensions for every ``ODEnet``.
            nonlinearity: activation name forwarded to every ``ODEnet``.
            squeeze: whether to insert the squeeze layer and the second group
                of CNF blocks.
            init_layer: optional layer prepended to the chain.
            n_blocks: number of CNF blocks in each group.
            cnf_kwargs: extra keyword arguments forwarded to every
                ``layers.CNF``; ``None`` (the default) means no extras.
        """
        # ``None`` replaces the former ``{}`` default so that no single dict
        # object is shared across every call of this constructor.
        if cnf_kwargs is None:
            cnf_kwargs = {}

        # All-ones strides: one entry for the input plus one per hidden dim.
        strides = tuple([1] + [1 for _ in idims])
        chain = []
        if init_layer is not None:
            chain.append(init_layer)

        def _make_odefunc(size):
            # Builds a new ODEnet and ODEfunc on every call.
            net = ODEnet(idims,
                         size,
                         strides,
                         True,
                         layer_type="concat",
                         nonlinearity=nonlinearity)
            return layers.ODEfunc(net)

        def _make_blocks(size):
            # ``n_blocks`` CNF blocks, each with its own ODEfunc for ``size``.
            return [
                layers.CNF(_make_odefunc(size), **cnf_kwargs)
                for _ in range(n_blocks)
            ]

        if squeeze:
            c, h, w = initial_size
            after_squeeze_size = c * 4, h // 2, w // 2
            # Keep the original construction order (pre, post, then the
            # squeeze layer) in case layer initialization is order-sensitive.
            pre = _make_blocks(initial_size)
            post = _make_blocks(after_squeeze_size)
            chain += pre + [layers.SqueezeLayer(2)] + post
        else:
            chain += _make_blocks(initial_size)

        super(StackedCNFLayers, self).__init__(chain)
예제 #3
0
파일: viz_cnf.py 프로젝트: diadochos/ffjord
 def build_cnf():
     """Assemble and return one ``layers.CNF`` from the run configuration.

     The dynamics are a ``layers.ODEnet`` wrapped in a ``layers.ODEfunc``.
     Settings are read from names this function does not define
     (``hidden_dims``, ``data_shape``, ``strides``, ``args``, ``layers``),
     presumably module-level globals -- confirm in the enclosing script.
     """
     net = layers.ODEnet(
         hidden_dims=hidden_dims, input_shape=data_shape, strides=strides,
         conv=args.conv, layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     func = layers.ODEfunc(
         diffeq=net, divergence_fn=args.divergence_fn,
         residual=args.residual, rademacher=args.rademacher,
     )
     return layers.CNF(odefunc=func, T=args.time_length, solver=args.solver)
예제 #4
0
 def build_cnf():
     """Assemble and return one ``layers.CNF`` using the autoencoder ODE parts.

     The dynamics are a ``layers.AutoencoderDiffEqNet`` wrapped in a
     ``layers.AutoencoderODEfunc``. Settings are read from names this
     function does not define (``hidden_dims``, ``data_shape``, ``strides``,
     ``args``, ``regularization_fns``, ``layers``), presumably module-level
     globals -- confirm in the enclosing script.
     """
     net = layers.AutoencoderDiffEqNet(
         hidden_dims=hidden_dims, input_shape=data_shape, strides=strides,
         conv=args.conv, layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     func = layers.AutoencoderODEfunc(
         autoencoder_diffeq=net, divergence_fn=args.divergence_fn,
         residual=args.residual, rademacher=args.rademacher,
     )
     return layers.CNF(
         odefunc=func, T=args.time_length,
         regularization_fns=regularization_fns, solver=args.solver,
     )
예제 #5
0
 def build_cnf():
     """Assemble and return one non-convolutional ``layers.CNF``.

     The dynamics are a ``layers.ODEnet`` built with ``conv=False``,
     ``strides=None`` and ``input_shape=(dims,)``, wrapped in a
     ``layers.ODEfunc``. Settings are read from names this function does not
     define (``hidden_dims``, ``dims``, ``args``, ``regularization_fns``,
     ``layers``), presumably module-level globals -- confirm in the
     enclosing script.
     """
     net = layers.ODEnet(
         hidden_dims=hidden_dims, input_shape=(dims,), strides=None,
         conv=False, layer_type=args.layer_type,
         nonlinearity=args.nonlinearity,
     )
     func = layers.ODEfunc(
         diffeq=net, divergence_fn=args.divergence_fn,
         residual=args.residual, rademacher=args.rademacher,
     )
     return layers.CNF(
         odefunc=func, T=args.time_length, train_T=args.train_T,
         regularization_fns=regularization_fns, solver=args.solver,
     )