def build_model():
    """Assemble the flow model from ``args.nblocks`` ``layers.imBlock`` stages.

    Every stage wraps two networks produced by ``build_nnet``. Both use the
    layer widths ``[data_dim, <args.dims split on '-'>, data_dim]`` and the
    activation selected by ``args.act``. ``args``, ``data_dim``, ``device``,
    ``ACTIVATION_FNS``, ``build_nnet`` and ``layers`` are taken from the
    surrounding scope.

    Returns:
        The ``layers.SequentialFlow`` of all stages after ``.to(device)``.
    """
    act = ACTIVATION_FNS[args.act]
    hidden_widths = [int(width) for width in args.dims.split('-')]
    layer_dims = [data_dim] + hidden_widths + [data_dim]

    def make_block():
        # Each block gets two separately constructed networks, built in order.
        return layers.imBlock(
            build_nnet(layer_dims, act),
            build_nnet(layer_dims, act),
            n_dist=args.n_dist,
            n_power_series=args.n_power_series,
            exact_trace=False,
            brute_force=args.brute_force,
            n_samples=args.n_samples,
            n_exact_terms=args.n_exact_terms,
            neumann_grad=False,
            grad_in_forward=False,  # toy data needn't save memory
            eps_forward=args.epsf,
        )

    stages = [make_block() for _ in range(args.nblocks)]
    return layers.SequentialFlow(stages).to(device)
Example #2
0
        def _resblock(initial_size, fc, idim=idim, first_resblock=True):
            """Create one ``layers.imBlock`` wrapping two identically configured nets.

            Args:
                initial_size: Input shape. It is forwarded unchanged to
                    ``FCNet`` as ``input_shape``; the convolutional variant
                    uses ``initial_size[0]`` as its input/output channel count.
                fc: Truthy selects the ``FCNet`` variant, falsy the
                    convolutional stack. In the convolutional variant it is
                    also passed to ``_lipschitz_layer``.
                idim: Hidden width / channel count. Defaults to the value of
                    the enclosing scope's ``idim`` when this function is
                    defined.
                first_resblock: When falsy and ``preact`` is set, the
                    convolutional net starts with an optional batch norm
                    followed by an activation.

            Returns:
                A ``layers.imBlock`` built from two freshly constructed
                sub-networks. All other settings come from the enclosing
                scope.
            """

            def make_block(make_net):
                # Both sub-networks are built (in order) before being wrapped.
                return layers.imBlock(
                    make_net(),
                    make_net(),
                    n_power_series=n_power_series,
                    n_dist=n_dist,
                    n_samples=n_samples,
                    n_exact_terms=n_exact_terms,
                    neumann_grad=neumann_grad,
                    grad_in_forward=grad_in_forward,
                )

            if fc:

                def make_fcnet():
                    # Hidden-layer count is derived from the number of
                    # dash-separated entries in ``kernels``.
                    return FCNet(
                        input_shape=initial_size,
                        idim=idim,
                        lipschitz_layer=_lipschitz_layer(True),
                        nhidden=len(kernels.split('-')) - 1,
                        coeff=coeff,
                        domains=domains,
                        codomains=codomains,
                        n_iterations=n_lipschitz_iters,
                        activation_fn=activation_fn,
                        preact=preact,
                        dropout=dropout,
                        sn_atol=sn_atol,
                        sn_rtol=sn_rtol,
                        learn_p=learn_p,
                    )

                return make_block(make_fcnet)

            def make_convnet():
                kernel_sizes = [int(tok) for tok in kernels.split('-')]
                if learn_p:
                    # One fresh zero-initialised parameter per layer; the
                    # codomain list is the domain list rotated left by one.
                    doms = [
                        nn.Parameter(torch.tensor(0.)) for _ in kernel_sizes
                    ]
                    codoms = doms[1:] + doms[:1]
                else:
                    doms, codoms = domains, codomains

                def conv(c_in, c_out, ksize, slot):
                    # Positional ``1`` and ``ksize // 2`` are presumably stride
                    # and padding -- confirm against ``_lipschitz_layer``.
                    make_layer = _lipschitz_layer(fc)
                    return make_layer(c_in,
                                      c_out,
                                      ksize,
                                      1,
                                      ksize // 2,
                                      coeff=coeff,
                                      n_iterations=n_lipschitz_iters,
                                      domain=doms[slot],
                                      codomain=codoms[slot],
                                      atol=sn_atol,
                                      rtol=sn_rtol)

                def norm(channels):
                    # Zero or one normalisation layer, depending on ``batchnorm``.
                    if batchnorm:
                        return [layers.MovingBatchNorm2d(channels)]
                    return []

                def activation(flag):
                    return ACT_FNS[activation_fn](flag)

                outer_channels = initial_size[0]
                stack = []

                # Optional pre-activation in front of the first layer.
                if not first_resblock and preact:
                    stack += norm(outer_channels)
                    stack.append(activation(False))

                # First layer: outer channels -> hidden channels.
                stack.append(conv(outer_channels, idim, kernel_sizes[0], 0))
                stack += norm(idim)
                stack.append(activation(True))

                # Middle layers keep the hidden channel count.
                for slot, ksize in enumerate(kernel_sizes[1:-1], start=1):
                    stack.append(conv(idim, idim, ksize, slot))
                    stack += norm(idim)
                    stack.append(activation(True))

                if dropout:
                    stack.append(nn.Dropout2d(dropout, inplace=True))

                # Last layer: hidden channels -> outer channels, no activation.
                stack.append(conv(idim, outer_channels, kernel_sizes[-1], -1))
                stack += norm(outer_channels)
                return nn.Sequential(*stack)

            return make_block(make_convnet)
                ))
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)
    elif args.arch == 'implicit':
        dims = [2] + list(map(int, args.dims.split('-'))) + [2]
        blocks = []
        if args.actnorm: blocks.append(layers.ActNorm1d(2))
        for _ in range(args.nblocks):
            blocks.append(
                layers.imBlock(
                    build_nnet(dims, activation_fn),
                    build_nnet(dims, activation_fn),
                    n_dist=args.n_dist,
                    n_power_series=args.n_power_series,
                    exact_trace=args.exact_trace,
                    brute_force=args.brute_force,
                    n_samples=args.n_samples,
                    neumann_grad=False,
                    grad_in_forward=False,  # toy data needn't save memory
                ))
        model = torch.nn.DataParallel(layers.SequentialFlow(blocks).to(device))
    elif args.arch == 'realnvp':
        blocks = []
        for _ in range(args.nblocks):
            blocks.append(layers.CouplingBlock(2, swap=False))
            blocks.append(layers.CouplingBlock(2, swap=True))
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)
    def __init__(
        self,
        in_planes,
        hidden,
        planes,
        stride=1,
        n_lipschitz_iters=None,
        sn_atol=1e-3,
        sn_rtol=1e-3,
    ):
        """Set up the implicit block and its downsample branch.

        Args:
            in_planes: Channel count entering and leaving each inner net.
            hidden: Channel count between the two convolutions of each net.
            planes: Combined with ``self.expansion`` to size the downsample
                branch.
            stride: Used only by the downsample convolution; the inner
                convolutions always use stride 1.
            n_lipschitz_iters: Forwarded to ``base_layers.get_conv2d`` as
                ``n_iterations``.
            sn_atol: Forwarded to ``base_layers.get_conv2d`` as ``atol``.
            sn_rtol: Forwarded to ``base_layers.get_conv2d`` as ``rtol``.
        """
        super(BasicImplicitBlock, self).__init__()
        conv_coeff = args.coeff
        self.initialized = False

        def conv3x3(c_in, c_out):
            # Bias-free 3x3 convolution with stride 1 and padding 1.
            return base_layers.get_conv2d(
                c_in,
                c_out,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
                coeff=conv_coeff,
                n_iterations=n_lipschitz_iters,
                domain=2,
                codomain=2,
                atol=sn_atol,
                rtol=sn_rtol,
            )

        def build_net():
            # conv -> relu -> conv -> relu, mapping in_planes back to in_planes.
            return nn.Sequential(
                conv3x3(in_planes, hidden),
                ACTIVATION_FNS['relu'](),
                conv3x3(hidden, in_planes),
                ACTIVATION_FNS['relu'](),
            )

        # Two independently constructed nets, built in order.
        first_net = build_net()
        second_net = build_net()
        self.block = layers.imBlock(first_net, second_net)

        out_planes = self.expansion * planes
        if stride == 1 and in_planes == out_planes:
            # Nothing to adapt: keep an empty module.
            self.downsample = nn.Sequential()
        else:
            # 1x1 convolution + batch norm + relu to change stride/channels.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes,
                          out_planes,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(out_planes),
                ACTIVATION_FNS['relu'](),
            )