Пример #1
0
def construct_model():
    """Assemble a flow over 2-dimensional inputs as a ``layers.SequentialFlow``.

    Configuration is read from the module-level ``args``:

    * ``args.nf`` -- when truthy, the flow is ``args.depth`` stacked
      ``layers.PlanarFlow(2)`` layers.
    * otherwise the flow is ``args.depth`` ``layers.CouplingLayer(2)`` layers
      whose ``swap`` flag is True on even indices; when ``args.glow`` is
      truthy each coupling layer is preceded by a
      ``layers.BruteForceLayer(2)``.

    Returns:
        The ``layers.SequentialFlow`` wrapping the assembled layer list.
    """
    if args.nf:
        planar_stack = [layers.PlanarFlow(2) for _ in range(args.depth)]
        return layers.SequentialFlow(planar_stack)

    stack = []
    for idx in range(args.depth):
        if args.glow:
            stack.append(layers.BruteForceLayer(2))
        # Alternate the swap flag from one coupling layer to the next.
        swap_halves = idx % 2 == 0
        stack.append(layers.CouplingLayer(2, swap=swap_halves))
    return layers.SequentialFlow(stack)
Пример #2
0
def build_model(input_dim):
    """Build a masked-coupling flow for inputs of size ``input_dim``.

    Configuration is read from the module-level ``args``:

    * ``args.dims`` -- dash-separated integers (e.g. ``"64-64"``) parsed into
      the tuple handed to every ``layers.MaskedCouplingLayer``.
    * ``args.depth`` -- number of coupling layers.
    * ``args.glow`` -- when truthy, a ``layers.BruteForceLayer`` is inserted
      before each coupling layer.
    * ``args.batch_norm`` / ``args.bn_lag`` -- when ``batch_norm`` is truthy,
      a ``layers.MovingBatchNorm1d`` follows each coupling layer.

    Args:
        input_dim: Dimension forwarded to every layer constructor.

    Returns:
        A ``layers.SequentialFlow`` over the assembled layers.
    """
    widths = tuple(int(tok) for tok in args.dims.split("-"))

    stack = []
    for idx in range(args.depth):
        if args.glow:
            stack.append(layers.BruteForceLayer(input_dim))

        # The swap flag flips on every other layer (True for even indices).
        coupling = layers.MaskedCouplingLayer(
            input_dim, widths, 'alternate', swap=idx % 2 == 0)
        stack.append(coupling)

        if args.batch_norm:
            norm = layers.MovingBatchNorm1d(input_dim, bn_lag=args.bn_lag)
            stack.append(norm)

    return layers.SequentialFlow(stack)
Пример #3
0
def build_model_tabular(args, dims, regularization_fns=None):
    """Build a stack of CNF blocks for flat inputs of size ``dims``.

    Args:
        args: Options object. Reads ``dims`` (dash-separated integers),
            ``layer_type``, ``nonlinearity``, ``divergence_fn``, ``residual``,
            ``rademacher``, ``time_length``, ``train_T``, ``solver``,
            ``num_blocks``, ``batch_norm`` and ``bn_lag``.
        dims: Input dimension; used as the one-element ``input_shape`` of the
            ODE network and as the size of the batch-norm layers.
        regularization_fns: Forwarded unchanged to every ``layers.CNF``.

    Returns:
        A ``layers.SequentialFlow`` holding ``args.num_blocks`` CNF blocks.
        With ``args.batch_norm`` the sequence is
        ``[bn, cnf, bn, cnf, bn, ...]`` -- one leading norm plus one norm
        after each CNF. ``set_cnf_options(args, model)`` is applied before
        the model is returned.
    """
    widths = tuple(int(tok) for tok in args.dims.split("-"))

    def build_cnf():
        # One CNF block: ODE network -> ODE function -> CNF wrapper.
        dynamics = layers.ODEnet(
            hidden_dims=widths,
            input_shape=(dims, ),
            strides=None,
            conv=False,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        ode_function = layers.ODEfunc(
            diffeq=dynamics,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=ode_function,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )

    def make_norm():
        return layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag)

    blocks = [build_cnf() for _ in range(args.num_blocks)]

    if args.batch_norm:
        # Per-block norms are created before the leading one, matching the
        # construction order of the layers.
        trailing_norms = [make_norm() for _ in range(args.num_blocks)]
        interleaved = [make_norm()]
        for cnf_block, norm in zip(blocks, trailing_norms):
            interleaved += [cnf_block, norm]
        blocks = interleaved

    model = layers.SequentialFlow(blocks)
    set_cnf_options(args, model)
    return model
Пример #4
0
def create_model(args, data_shape, regularization_fns):
    """Create the flow model selected by ``args``.

    Three architectures are possible, checked in this order:

    1. ``args.multiscale`` -- an ``odenvp.ODENVP`` model.
    2. ``args.parallel`` -- a ``multiscale_parallel.MultiscaleParallelCNF``.
    3. otherwise -- a ``layers.SequentialFlow`` made of an input transform
       (``LogitTransform`` when ``args.alpha > 0``, else
       ``ZeroMeanTransform``), ``args.num_blocks`` CNF blocks and, when
       ``args.batch_norm`` is truthy, a final ``MovingBatchNorm2d``.

    Args:
        args: Options object. ``dims`` and ``strides`` are comma-separated
            integer strings and are parsed regardless of the branch taken.
        data_shape: Shape of a single input; ``data_shape[0]`` sizes the
            batch-norm layer, and the multiscale models receive
            ``(args.batch_size, *data_shape)``.
        regularization_fns: Forwarded to the CNF blocks (not used by the
            parallel model).

    Returns:
        The constructed model.
    """
    widths = tuple(int(tok) for tok in args.dims.split(","))
    stride_spec = tuple(int(tok) for tok in args.strides.split(","))

    if args.multiscale:
        return odenvp.ODENVP(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=widths,
            nonlinearity=args.nonlinearity,
            alpha=args.alpha,
            cnf_kwargs={
                "T": args.time_length,
                "train_T": args.train_T,
                "regularization_fns": regularization_fns
            },
        )

    if args.parallel:
        return multiscale_parallel.MultiscaleParallelCNF(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=widths,
            alpha=args.alpha,
            time_length=args.time_length,
        )

    def build_autoencoder_cnf():
        # CNF block driven by an autoencoder-style network.
        # Note: no train_T is passed to the CNF in this variant.
        net = layers.AutoencoderDiffEqNet(
            hidden_dims=widths,
            input_shape=data_shape,
            strides=stride_spec,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        ode_function = layers.AutoencoderODEfunc(
            autoencoder_diffeq=net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=ode_function,
            T=args.time_length,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )

    def build_plain_cnf():
        # CNF block driven by a plain ODEnet.
        net = layers.ODEnet(
            hidden_dims=widths,
            input_shape=data_shape,
            strides=stride_spec,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        ode_function = layers.ODEfunc(
            diffeq=net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=ode_function,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )

    # The builder is chosen once, before any block is created.
    build_cnf = build_autoencoder_cnf if args.autoencode else build_plain_cnf

    if args.alpha > 0:
        input_transform = layers.LogitTransform(alpha=args.alpha)
    else:
        input_transform = layers.ZeroMeanTransform()

    sequence = [input_transform]
    for _ in range(args.num_blocks):
        sequence.append(build_cnf())
    if args.batch_norm:
        sequence.append(layers.MovingBatchNorm2d(data_shape[0]))

    return layers.SequentialFlow(sequence)
Пример #5
0
def build_model(args, state_dict):
    """Rebuild a single-block CNF model and load saved weights into it.

    The dataset is loaded through ``get_dataset(args)`` to obtain the data
    shape. The model is a ``layers.SequentialFlow`` of a
    ``LogitTransform(alpha=args.alpha)``, one CNF block and, when
    ``args.batch_norm`` is truthy, a ``MovingBatchNorm2d``. When
    ``args.spectral_norm`` is truthy, ``add_spectral_norm`` is applied before
    ``state_dict`` is loaded.

    Args:
        args: Options object. ``dims`` and ``strides`` are comma-separated
            integer strings; ``autoencode`` selects the network family used
            for the CNF block.
        state_dict: State passed to ``model.load_state_dict``.

    Returns:
        A ``(model, dataset)`` tuple, where ``dataset`` is the ``dataset``
        attribute of the test loader returned by ``get_dataset``.
    """
    # Only the test loader and the data shape are needed here.
    _, test_loader, data_shape = get_dataset(args)

    widths = tuple(int(tok) for tok in args.dims.split(","))
    stride_spec = tuple(int(tok) for tok in args.strides.split(","))

    def build_autoencoder_cnf():
        # Velocity field parameterized by an autoencoder-style network.
        net = layers.AutoencoderDiffEqNet(
            hidden_dims=widths,
            input_shape=data_shape,
            strides=stride_spec,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        ode_function = layers.AutoencoderODEfunc(
            autoencoder_diffeq=net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=ode_function,
            T=args.time_length,
            solver=args.solver,
        )

    def build_plain_cnf():
        # Velocity field parameterized by a plain ODEnet.
        net = layers.ODEnet(
            hidden_dims=widths,
            input_shape=data_shape,
            strides=stride_spec,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        ode_function = layers.ODEfunc(
            diffeq=net,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        return layers.CNF(
            odefunc=ode_function,
            T=args.time_length,
            solver=args.solver,
        )

    build_cnf = build_autoencoder_cnf if args.autoencode else build_plain_cnf

    sequence = [layers.LogitTransform(alpha=args.alpha), build_cnf()]
    if args.batch_norm:
        sequence.append(layers.MovingBatchNorm2d(data_shape[0]))
    model = layers.SequentialFlow(sequence)

    # Spectral norm must be attached before loading so the parameter names
    # in the module match whatever add_spectral_norm produces.
    # NOTE(review): assumes state_dict was saved from a model built with the
    # same options -- confirm against the checkpoint writer.
    if args.spectral_norm:
        add_spectral_norm(model)

    model.load_state_dict(state_dict)

    return model, test_loader.dataset