        def _resblock(initial_size, fc, idim=idim, first_resblock=False):
            if fc:
                nonloc_scope.swap = not nonloc_scope.swap
                return layers.CouplingBlock(
                    initial_size[0],
                    FCNet(
                        input_shape=initial_size,
                        idim=idim,
                        lipschitz_layer=_weight_layer(True),
                        nhidden=len(kernels.split('-')) - 1,
                        activation_fn=activation_fn,
                        preact=preact,
                        dropout=dropout,
                        coeff=None,
                        domains=None,
                        codomains=None,
                        n_iterations=None,
                        sn_atol=None,
                        sn_rtol=None,
                        learn_p=None,
                        div_in=2,
                    ),
                    swap=nonloc_scope.swap,
                )
            else:
                ks = list(map(int, kernels.split('-')))

                if init_layer is None:
                    _block = layers.ChannelCouplingBlock
                    _mask_type = 'channel'
                    div_in = 2
                    mult_out = 1
                else:
                    _block = layers.MaskedCouplingBlock
                    _mask_type = 'checkerboard'
                    div_in = 1
                    mult_out = 2

                nonloc_scope.swap = not nonloc_scope.swap
                _mask_type += '1' if nonloc_scope.swap else '0'

                nnet = []
                if not first_resblock and preact:
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                    nnet.append(ACT_FNS[activation_fn](False))
                nnet.append(_weight_layer(fc)(initial_size[0] // div_in, idim, ks[0], 1, ks[0] // 2))
                if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
                nnet.append(ACT_FNS[activation_fn](True))
                for i, k in enumerate(ks[1:-1]):
                    nnet.append(_weight_layer(fc)(idim, idim, k, 1, k // 2))
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
                    nnet.append(ACT_FNS[activation_fn](True))
                if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
                nnet.append(_weight_layer(fc)(idim, initial_size[0] * mult_out, ks[-1], 1, ks[-1] // 2))
                if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0] * mult_out))  # final conv emits mult_out * C channels

                return _block(initial_size[0], nn.Sequential(*nnet), mask_type=_mask_type)
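# A minimal, self-contained sketch (not taken from the source) of the alternating
# mask-parity trick used above: a small mutable namespace carries `swap` across
# successive _resblock calls, so consecutive coupling blocks get complementary
# masks ('checkerboard0' / 'checkerboard1', or swapped channel halves).
import types

nonloc_scope = types.SimpleNamespace(swap=True)

def next_mask_type(base='checkerboard'):
    # Toggle parity, then encode it in the mask-type string, as the code above does.
    nonloc_scope.swap = not nonloc_scope.swap
    return base + ('1' if nonloc_scope.swap else '0')

print(next_mask_type(), next_mask_type())  # checkerboard0 checkerboard1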
Example #2
def create_cnf_model(args, data_shape, regularization_fns):
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    def build_cnf():
        diffeq = layers.ODEnet(
            hidden_dims=hidden_dims,
            input_shape=data_shape,
            strides=strides,
            conv=args.conv,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
        )
        odefunc = layers.ODEfunc(
            diffeq=diffeq,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
        )
        cnf = layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
        )
        return cnf

    chain = [layers.LogitTransform(alpha=args.cnf_alpha)] if args.cnf_alpha > 0 else [layers.ZeroMeanTransform()]
    chain = chain + [build_cnf() for _ in range(args.num_blocks)]
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    model = layers.SequentialFlow(chain)
    return model
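# A hedged usage sketch for create_cnf_model: the argument values below are
# illustrative assumptions, not defaults taken from this file, and `layers` must be
# importable from the FFJORD codebase for the call to work.
import argparse

example_args = argparse.Namespace(
    dims="64,64,64", strides="1,1,1,1", conv=True,
    layer_type="concat", nonlinearity="softplus",
    divergence_fn="approximate", residual=False, rademacher=True,
    time_length=1.0, train_T=True, solver="dopri5",
    cnf_alpha=1e-6, num_blocks=2, batch_norm=True,
)
# CIFAR-10-shaped input; regularization_fns=None is assumed acceptable here, whereas the
# training scripts normally pass the tuple built by create_regularization_fns(args).
cnf_model = create_cnf_model(example_args, data_shape=(3, 32, 32), regularization_fns=None)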
Example #3
 def _resblock(initial_size, fc, idim=idim, first_resblock=False):
     if fc:
         return layers.iResBlock(
             FCNet(
                 input_shape=initial_size,
                 idim=idim,
                 lipschitz_layer=_lipschitz_layer(True),
                 nhidden=len(kernels.split('-')) - 1,
                 coeff=coeff,
                 domains=domains,
                 codomains=codomains,
                 n_iterations=n_lipschitz_iters,
                 activation_fn=activation_fn,
                 preact=preact,
                 dropout=dropout,
                 sn_atol=sn_atol,
                 sn_rtol=sn_rtol,
                 learn_p=learn_p,
             ),
             n_power_series=n_power_series,
             n_dist=n_dist,
             n_samples=n_samples,
             n_exact_terms=n_exact_terms,
             neumann_grad=neumann_grad,
             grad_in_forward=grad_in_forward,
         )
     else:
         ks = list(map(int, kernels.split('-')))
         if learn_p:
             _domains = [
                 nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))
             ]
             _codomains = _domains[1:] + [_domains[0]]
         else:
             _domains = domains
             _codomains = codomains
         nnet = []
         if not first_resblock and preact:
             if batchnorm:
                 nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
             nnet.append(ACT_FNS[activation_fn](False))
         nnet.append(
             _lipschitz_layer(fc)(initial_size[0],
                                  idim,
                                  ks[0],
                                  1,
                                  ks[0] // 2,
                                  coeff=coeff,
                                  n_iterations=n_lipschitz_iters,
                                  domain=_domains[0],
                                  codomain=_codomains[0],
                                  atol=sn_atol,
                                  rtol=sn_rtol))
         if batchnorm:
             nnet.append(layers.MovingBatchNorm2d(idim))
         nnet.append(ACT_FNS[activation_fn](True))
         for i, k in enumerate(ks[1:-1]):
             nnet.append(
                 _lipschitz_layer(fc)(idim,
                                      idim,
                                      k,
                                      1,
                                      k // 2,
                                      coeff=coeff,
                                      n_iterations=n_lipschitz_iters,
                                      domain=_domains[i + 1],
                                      codomain=_codomains[i + 1],
                                      atol=sn_atol,
                                      rtol=sn_rtol))
             if batchnorm:
                 nnet.append(layers.MovingBatchNorm2d(idim))
             nnet.append(ACT_FNS[activation_fn](True))
         if dropout:
             nnet.append(nn.Dropout2d(dropout, inplace=True))
         nnet.append(
             _lipschitz_layer(fc)(idim,
                                  initial_size[0],
                                  ks[-1],
                                  1,
                                  ks[-1] // 2,
                                  coeff=coeff,
                                  n_iterations=n_lipschitz_iters,
                                  domain=_domains[-1],
                                  codomain=_codomains[-1],
                                  atol=sn_atol,
                                  rtol=sn_rtol))
         if batchnorm:
             nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
         return layers.iResBlock(
             nn.Sequential(*nnet),
             n_power_series=n_power_series,
             n_dist=n_dist,
             n_samples=n_samples,
             n_exact_terms=n_exact_terms,
             neumann_grad=neumann_grad,
             grad_in_forward=grad_in_forward,
         )
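# A self-contained sketch (with crude weight scaling standing in for the spectral
# normalization that _lipschitz_layer provides) of why the residual branch above must
# be Lipschitz-bounded: if Lip(g) < 1, the block y = x + g(x) is invertible and x can
# be recovered by fixed-point iteration, which is what layers.iResBlock relies on.
import torch
import torch.nn as nn

g = nn.Sequential(nn.Linear(4, 4), nn.Tanh(), nn.Linear(4, 4))
with torch.no_grad():
    for p in g.parameters():
        p.mul_(0.2)  # shrink weights so the branch is a contraction

x = torch.randn(8, 4)
y = x + g(x)              # forward pass of an invertible residual block

x_hat = y.clone()         # inverse via fixed-point iteration: x_{k+1} = y - g(x_k)
for _ in range(50):
    x_hat = y - g(x_hat)

print(torch.allclose(x, x_hat, atol=1e-5))  # True: the input is recovered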
Example #4
 def build_nnet():
     ks = list(map(int, kernels.split('-')))
     if learn_p:
         _domains = [
             nn.Parameter(torch.tensor(0.))
             for _ in range(len(ks))
         ]
         _codomains = _domains[1:] + [_domains[0]]
     else:
         _domains = domains
         _codomains = codomains
     nnet = []
     if not first_resblock and preact:
         if batchnorm:
             nnet.append(
                 layers.MovingBatchNorm2d(initial_size[0]))
         nnet.append(ACT_FNS[activation_fn](False))
     nnet.append(
         _lipschitz_layer(fc)(initial_size[0],
                              idim,
                              ks[0],
                              1,
                              ks[0] // 2,
                              coeff=coeff,
                              n_iterations=n_lipschitz_iters,
                              domain=_domains[0],
                              codomain=_codomains[0],
                              atol=sn_atol,
                              rtol=sn_rtol))
     if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
     nnet.append(ACT_FNS[activation_fn](True))
     for i, k in enumerate(ks[1:-1]):
         nnet.append(
             _lipschitz_layer(fc)(
                 idim,
                 idim,
                 k,
                 1,
                 k // 2,
                 coeff=coeff,
                 n_iterations=n_lipschitz_iters,
                 domain=_domains[i + 1],
                 codomain=_codomains[i + 1],
                 atol=sn_atol,
                 rtol=sn_rtol))
         if batchnorm:
             nnet.append(layers.MovingBatchNorm2d(idim))
         nnet.append(ACT_FNS[activation_fn](True))
     if dropout:
         nnet.append(nn.Dropout2d(dropout, inplace=True))
     nnet.append(
         _lipschitz_layer(fc)(idim,
                              initial_size[0],
                              ks[-1],
                              1,
                              ks[-1] // 2,
                              coeff=coeff,
                              n_iterations=n_lipschitz_iters,
                              domain=_domains[-1],
                              codomain=_codomains[-1],
                              atol=sn_atol,
                              rtol=sn_rtol))
     if batchnorm:
         nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
     return nn.Sequential(*nnet)
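# A quick shape sanity check (a sketch, with plain nn.Conv2d standing in for the
# spectrally-normalized _lipschitz_layer): stride 1 with padding k // 2 preserves the
# spatial size, and the channel path C -> idim -> ... -> idim -> C returns to the input
# width, which is exactly what a residual branch like the one above requires.
import torch
import torch.nn as nn

C, idim, ks = 3, 32, [3, 1, 3]   # e.g. kernels='3-1-3'
convs = [nn.Conv2d(C, idim, ks[0], 1, ks[0] // 2)]
convs += [nn.Conv2d(idim, idim, k, 1, k // 2) for k in ks[1:-1]]
convs += [nn.Conv2d(idim, C, ks[-1], 1, ks[-1] // 2)]
stack = nn.Sequential(*convs)

x = torch.randn(2, C, 16, 16)
print(stack(x).shape)            # torch.Size([2, 3, 16, 16]), same as the input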
Example #5
def create_model(args, data_shape, regularization_fns):
    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    if args.multiscale:
        model = odenvp.ODENVP(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            nonlinearity=args.nonlinearity,
            alpha=args.alpha,
            cnf_kwargs={
                "T": args.time_length,
                "train_T": args.train_T,
                "regularization_fns": regularization_fns
            },
        )
    elif args.parallel:
        model = multiscale_parallel.MultiscaleParallelCNF(
            (args.batch_size, *data_shape),
            n_blocks=args.num_blocks,
            intermediate_dims=hidden_dims,
            alpha=args.alpha,
            time_length=args.time_length,
        )
    else:
        if args.autoencode:

            def build_cnf():
                autoencoder_diffeq = layers.AutoencoderDiffEqNet(
                    hidden_dims=hidden_dims,
                    input_shape=data_shape,
                    strides=strides,
                    conv=args.conv,
                    layer_type=args.layer_type,
                    nonlinearity=args.nonlinearity,
                )
                odefunc = layers.AutoencoderODEfunc(
                    autoencoder_diffeq=autoencoder_diffeq,
                    divergence_fn=args.divergence_fn,
                    residual=args.residual,
                    rademacher=args.rademacher,
                )
                cnf = layers.CNF(
                    odefunc=odefunc,
                    T=args.time_length,
                    regularization_fns=regularization_fns,
                    solver=args.solver,
                )
                return cnf
        else:

            def build_cnf():
                diffeq = layers.ODEnet(
                    hidden_dims=hidden_dims,
                    input_shape=data_shape,
                    strides=strides,
                    conv=args.conv,
                    layer_type=args.layer_type,
                    nonlinearity=args.nonlinearity,
                )
                odefunc = layers.ODEfunc(
                    diffeq=diffeq,
                    divergence_fn=args.divergence_fn,
                    residual=args.residual,
                    rademacher=args.rademacher,
                )
                cnf = layers.CNF(
                    odefunc=odefunc,
                    T=args.time_length,
                    train_T=args.train_T,
                    regularization_fns=regularization_fns,
                    solver=args.solver,
                )
                return cnf

        chain = [layers.LogitTransform(alpha=args.alpha)] if args.alpha > 0 else [layers.ZeroMeanTransform()]
        chain = chain + [build_cnf() for _ in range(args.num_blocks)]
        if args.batch_norm:
            chain.append(layers.MovingBatchNorm2d(data_shape[0]))
        model = layers.SequentialFlow(chain)
    return model
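# A hedged sketch of the preprocessing choice above: with alpha > 0 the chain starts
# with a logit transform (as popularized by Real NVP and used in FFJORD) that maps
# pixel values in [0, 1] to an unbounded space; with alpha == 0 only a zero-centering
# shift is applied. The formula below is the conventional one and is an assumption,
# not code copied from layers.LogitTransform.
import torch

def logit_transform(x, alpha=1e-6):
    s = alpha + (1 - 2 * alpha) * x          # squeeze into (alpha, 1 - alpha)
    return torch.log(s) - torch.log(1 - s)   # logit

def inverse_logit_transform(y, alpha=1e-6):
    s = torch.sigmoid(y)
    return (s - alpha) / (1 - 2 * alpha)

x = torch.rand(4, 3, 8, 8)
print(torch.allclose(x, inverse_logit_transform(logit_transform(x)), atol=1e-5))  # True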
def build_model(args, state_dict):
    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    hidden_dims = tuple(map(int, args.dims.split(",")))
    strides = tuple(map(int, args.strides.split(",")))

    # neural net that parameterizes the velocity field
    if args.autoencode:

        def build_cnf():
            autoencoder_diffeq = layers.AutoencoderDiffEqNet(
                hidden_dims=hidden_dims,
                input_shape=data_shape,
                strides=strides,
                conv=args.conv,
                layer_type=args.layer_type,
                nonlinearity=args.nonlinearity,
            )
            odefunc = layers.AutoencoderODEfunc(
                autoencoder_diffeq=autoencoder_diffeq,
                divergence_fn=args.divergence_fn,
                residual=args.residual,
                rademacher=args.rademacher,
            )
            cnf = layers.CNF(
                odefunc=odefunc,
                T=args.time_length,
                solver=args.solver,
            )
            return cnf
    else:

        def build_cnf():
            diffeq = layers.ODEnet(
                hidden_dims=hidden_dims,
                input_shape=data_shape,
                strides=strides,
                conv=args.conv,
                layer_type=args.layer_type,
                nonlinearity=args.nonlinearity,
            )
            odefunc = layers.ODEfunc(
                diffeq=diffeq,
                divergence_fn=args.divergence_fn,
                residual=args.residual,
                rademacher=args.rademacher,
            )
            cnf = layers.CNF(
                odefunc=odefunc,
                T=args.time_length,
                solver=args.solver,
            )
            return cnf

    chain = [layers.LogitTransform(alpha=args.alpha), build_cnf()]
    if args.batch_norm:
        chain.append(layers.MovingBatchNorm2d(data_shape[0]))
    model = layers.SequentialFlow(chain)

    if args.spectral_norm:
        add_spectral_norm(model)

    model.load_state_dict(state_dict)

    return model, test_loader.dataset
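# A hedged usage sketch: restoring a trained model from a checkpoint and putting it in
# eval mode. The file name and the 'args' / 'state_dict' checkpoint keys are assumptions
# about how the checkpoint was saved, not facts taken from this file.
import torch

checkpt = torch.load("checkpt.pth", map_location="cpu")
model, test_set = build_model(checkpt["args"], checkpt["state_dict"])
model.eval()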
        def _resblock(initial_size, fc, idim=idim, first_resblock=False,
                      densenet=densenet, densenet_depth=densenet_depth,
                      densenet_growth=densenet_growth, fc_densenet_growth=fc_densenet_growth,
                      learnable_concat=learnable_concat, lip_coeff=lip_coeff):
            if fc:
                return layers.iResBlock(
                    FCNet(
                        input_shape=initial_size,
                        idim=idim,
                        lipschitz_layer=_lipschitz_layer(True),
                        nhidden=len(kernels.split('-')) - 1,
                        coeff=coeff,
                        domains=domains,
                        codomains=codomains,
                        n_iterations=n_lipschitz_iters,
                        activation_fn=activation_fn,
                        preact=preact,
                        dropout=dropout,
                        sn_atol=sn_atol,
                        sn_rtol=sn_rtol,
                        learn_p=learn_p,
                        densenet=densenet,
                        densenet_depth=densenet_depth,
                        densenet_growth=densenet_growth,
                        fc_densenet_growth=fc_densenet_growth,
                        learnable_concat=learnable_concat,
                        lip_coeff=lip_coeff,
                    ),
                    n_power_series=n_power_series,
                    n_dist=n_dist,
                    n_samples=n_samples,
                    n_exact_terms=n_exact_terms,
                    neumann_grad=neumann_grad,
                    grad_in_forward=grad_in_forward,
                )
            else:
                if densenet:
                    ks = list(map(int, kernels.split('-')))
                    if learn_p:
                        _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                        _codomains = _domains[1:] + [_domains[0]]
                    else:
                        _domains = domains
                        _codomains = codomains

                    # Initializing nnet as empty list
                    nnet = []

                    total_in_channels = initial_size[0]
                    for i in range(densenet_depth):
                        part_net = []

                        # Change growth size for CLipSwish:
                        if activation_fn == 'CLipSwish':
                            output_channels = densenet_growth // 2
                        else:
                            output_channels = densenet_growth

                        part_net.append(
                            _lipschitz_layer(fc)(
                                total_in_channels, output_channels, 3, 1, padding=1, coeff=coeff,
                                n_iterations=n_lipschitz_iters,
                                domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
                            )
                        )

                        if batchnorm: part_net.append(layers.MovingBatchNorm2d(output_channels))  # conv output has output_channels channels
                        part_net.append(ACT_FNS[activation_fn](False))

                        nnet.append(
                            layers.LipschitzDenseLayer(
                                nn.Sequential(*part_net),
                                learnable_concat,
                                lip_coeff
                            )
                        )

                        total_in_channels += densenet_growth

                    # Last layer 1x1
                    nnet.append(
                        _lipschitz_layer(fc)(
                            total_in_channels, initial_size[0], 1, 1, padding=0, coeff=coeff, n_iterations=n_lipschitz_iters,
                            domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
                        )
                    )

                    return layers.iResBlock(
                        nn.Sequential(*nnet),
                        n_power_series=n_power_series,
                        n_dist=n_dist,
                        n_samples=n_samples,
                        n_exact_terms=n_exact_terms,
                        neumann_grad=neumann_grad,
                        grad_in_forward=grad_in_forward,
                    )

                # RESNET
                else:
                    ks = list(map(int, kernels.split('-')))
                    if learn_p:
                        _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                        _codomains = _domains[1:] + [_domains[0]]
                    else:
                        _domains = domains
                        _codomains = codomains

                    # Initializing nnet as empty list
                    nnet = []

                    # Change sizes for CLipSwish:
                    if activation_fn == 'CLipSwish':
                        idim_out = idim // 2

                        if not first_resblock and preact:
                            in_size = initial_size[0] * 2
                        else:
                            in_size = initial_size[0]

                    else:
                        idim_out = idim
                        in_size = initial_size[0]

                    if not first_resblock and preact:
                        if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                        nnet.append(ACT_FNS[activation_fn](False))

                    nnet.append(
                        _lipschitz_layer(fc)(
                            in_size, idim_out, ks[0], 1, ks[0] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                            domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
                        )
                    )
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim_out))  # conv output has idim_out channels (idim // 2 for CLipSwish)
                    nnet.append(ACT_FNS[activation_fn](True))
                    for i, k in enumerate(ks[1:-1]):
                        nnet.append(
                            _lipschitz_layer(fc)(
                                idim, idim_out, k, 1, k // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                                domain=_domains[i + 1], codomain=_codomains[i + 1], atol=sn_atol, rtol=sn_rtol
                            )
                        )
                        if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim_out))  # conv output has idim_out channels
                        nnet.append(ACT_FNS[activation_fn](True))
                    if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
                    nnet.append(
                        _lipschitz_layer(fc)(
                            idim, initial_size[0], ks[-1], 1, ks[-1] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                            domain=_domains[-1], codomain=_codomains[-1], atol=sn_atol, rtol=sn_rtol
                        )
                    )
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                    return layers.iResBlock(
                        nn.Sequential(*nnet),
                        n_power_series=n_power_series,
                        n_dist=n_dist,
                        n_samples=n_samples,
                        n_exact_terms=n_exact_terms,
                        neumann_grad=neumann_grad,
                        grad_in_forward=grad_in_forward,
                    )
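# A self-contained sketch (plain convolutions standing in for the Lipschitz layers) of
# the dense-growth pattern in the densenet branch above: each dense layer maps its input
# to `growth` new channels and concatenates them with its input, so the channel count
# grows by `growth` per layer until a final 1x1 convolution projects back to C channels.
# The LipschitzDenseLayer above additionally takes learnable_concat and lip_coeff, which
# this sketch ignores.
import torch
import torch.nn as nn

class DenseLayerSketch(nn.Module):
    def __init__(self, in_ch, growth):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, growth, 3, 1, padding=1)

    def forward(self, x):
        return torch.cat([x, self.conv(x)], dim=1)  # concatenate input with new features

C, growth, depth = 3, 16, 3
blocks, in_ch = [], C
for _ in range(depth):
    blocks.append(DenseLayerSketch(in_ch, growth))
    in_ch += growth                                 # channel count grows each layer
blocks.append(nn.Conv2d(in_ch, C, 1))               # final 1x1 back to C channels
net = nn.Sequential(*blocks)

x = torch.randn(2, C, 8, 8)
print(net(x).shape)                                 # torch.Size([2, 3, 8, 8])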