Example #1
 def _resblock(initial_size, fc, idim=idim, first_resblock=False):
     if fc:
         return layers.iResBlock(
             FCNet(
                 input_shape=initial_size,
                 idim=idim,
                 lipschitz_layer=_lipschitz_layer(True),
                 nhidden=len(kernels.split('-')) - 1,
                 coeff=coeff,
                 domains=domains,
                 codomains=codomains,
                 n_iterations=n_lipschitz_iters,
                 activation_fn=activation_fn,
                 preact=preact,
                 dropout=dropout,
                 sn_atol=sn_atol,
                 sn_rtol=sn_rtol,
                 learn_p=learn_p,
             ),
             n_power_series=n_power_series,
             n_dist=n_dist,
             n_samples=n_samples,
             n_exact_terms=n_exact_terms,
             neumann_grad=neumann_grad,
             grad_in_forward=grad_in_forward,
         )
     else:
         ks = list(map(int, kernels.split('-')))
         if learn_p:
             _domains = [
                 nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))
             ]
             _codomains = _domains[1:] + [_domains[0]]
         else:
             _domains = domains
             _codomains = codomains
         nnet = []
         if not first_resblock and preact:
             if batchnorm:
                 nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
             nnet.append(ACT_FNS[activation_fn](False))
         nnet.append(
             _lipschitz_layer(fc)(initial_size[0],
                                  idim,
                                  ks[0],
                                  1,
                                  ks[0] // 2,
                                  coeff=coeff,
                                  n_iterations=n_lipschitz_iters,
                                  domain=_domains[0],
                                  codomain=_codomains[0],
                                  atol=sn_atol,
                                  rtol=sn_rtol))
         if batchnorm:
             nnet.append(layers.MovingBatchNorm2d(idim))
         nnet.append(ACT_FNS[activation_fn](True))
         for i, k in enumerate(ks[1:-1]):
             nnet.append(
                 _lipschitz_layer(fc)(idim,
                                      idim,
                                      k,
                                      1,
                                      k // 2,
                                      coeff=coeff,
                                      n_iterations=n_lipschitz_iters,
                                      domain=_domains[i + 1],
                                      codomain=_codomains[i + 1],
                                      atol=sn_atol,
                                      rtol=sn_rtol))
             if batchnorm:
                 nnet.append(layers.MovingBatchNorm2d(idim))
             nnet.append(ACT_FNS[activation_fn](True))
         if dropout:
             nnet.append(nn.Dropout2d(dropout, inplace=True))
         nnet.append(
             _lipschitz_layer(fc)(idim,
                                  initial_size[0],
                                  ks[-1],
                                  1,
                                  ks[-1] // 2,
                                  coeff=coeff,
                                  n_iterations=n_lipschitz_iters,
                                  domain=_domains[-1],
                                  codomain=_codomains[-1],
                                  atol=sn_atol,
                                  rtol=sn_rtol))
         if batchnorm:
             nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
         return layers.iResBlock(
             nn.Sequential(*nnet),
             n_power_series=n_power_series,
             n_dist=n_dist,
             n_samples=n_samples,
             n_exact_terms=n_exact_terms,
             neumann_grad=neumann_grad,
             grad_in_forward=grad_in_forward,
         )
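A minimal usage sketch, assuming the enclosing scope provides `initial_size` and the closure variables used above; the stacking loop and `n_blocks` are illustrative, not part of the original:

# Hypothetical stacking loop: only the first block skips the pre-activation
# branch, via the `first_resblock` flag.
chain = []
for i in range(n_blocks):
    chain.append(_resblock(initial_size, fc=False, first_resblock=(i == 0)))
model = layers.SequentialFlow(chain)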
Example #2
        def _resblock(initial_size, fc, idim=idim, first_resblock=False,
                      densenet=densenet, densenet_depth=densenet_depth,
                      densenet_growth=densenet_growth, fc_densenet_growth=fc_densenet_growth,
                      learnable_concat=learnable_concat, lip_coeff=lip_coeff):
            if fc:
                return layers.iResBlock(
                    FCNet(
                        input_shape=initial_size,
                        idim=idim,
                        lipschitz_layer=_lipschitz_layer(True),
                        nhidden=len(kernels.split('-')) - 1,
                        coeff=coeff,
                        domains=domains,
                        codomains=codomains,
                        n_iterations=n_lipschitz_iters,
                        activation_fn=activation_fn,
                        preact=preact,
                        dropout=dropout,
                        sn_atol=sn_atol,
                        sn_rtol=sn_rtol,
                        learn_p=learn_p,
                        densenet=densenet,
                        densenet_depth=densenet_depth,
                        densenet_growth=densenet_growth,
                        fc_densenet_growth=fc_densenet_growth,
                        learnable_concat=learnable_concat,
                        lip_coeff=lip_coeff,
                    ),
                    n_power_series=n_power_series,
                    n_dist=n_dist,
                    n_samples=n_samples,
                    n_exact_terms=n_exact_terms,
                    neumann_grad=neumann_grad,
                    grad_in_forward=grad_in_forward,
                )
            else:
                if densenet:
                    ks = list(map(int, kernels.split('-')))
                    if learn_p:
                        _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                        _codomains = _domains[1:] + [_domains[0]]
                    else:
                        _domains = domains
                        _codomains = codomains

                    # Initializing nnet as empty list
                    nnet = []

                    total_in_channels = initial_size[0]
                    for i in range(densenet_depth):
                        part_net = []

                        # Change growth size for CLipSwish:
                        if activation_fn == 'CLipSwish':
                            output_channels = densenet_growth // 2
                        else:
                            output_channels = densenet_growth

                        part_net.append(
                            _lipschitz_layer(fc)(
                                total_in_channels, output_channels, 3, 1, padding=1, coeff=coeff,
                                n_iterations=n_lipschitz_iters,
                                domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
                            )
                        )

                        if batchnorm: part_net.append(layers.MovingBatchNorm2d(densenet_growth))
                        part_net.append(ACT_FNS[activation_fn](False))

                        nnet.append(
                            layers.LipschitzDenseLayer(
                                nn.Sequential(*part_net),
                                learnable_concat,
                                lip_coeff
                            )
                        )

                        total_in_channels += densenet_growth

                    # Last layer 1x1
                    nnet.append(
                        _lipschitz_layer(fc)(
                            total_in_channels, initial_size[0], 1, 1, padding=0, coeff=coeff, n_iterations=n_lipschitz_iters,
                            domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
                        )
                    )

                    return layers.iResBlock(
                        nn.Sequential(*nnet),
                        n_power_series=n_power_series,
                        n_dist=n_dist,
                        n_samples=n_samples,
                        n_exact_terms=n_exact_terms,
                        neumann_grad=neumann_grad,
                        grad_in_forward=grad_in_forward,
                    )

                # RESNET
                else:
                    ks = list(map(int, kernels.split('-')))
                    if learn_p:
                        _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
                        _codomains = _domains[1:] + [_domains[0]]
                    else:
                        _domains = domains
                        _codomains = codomains

                    # Initializing nnet as empty list
                    nnet = []

                    # Change sizes for CLipSwish:
                    if activation_fn == 'CLipSwish':
                        idim_out = idim // 2

                        if not first_resblock and preact:
                            in_size = initial_size[0] * 2
                        else:
                            in_size = initial_size[0]

                    else:
                        idim_out = idim
                        in_size = initial_size[0]

                    if not first_resblock and preact:
                        if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                        nnet.append(ACT_FNS[activation_fn](False))

                    nnet.append(
                        _lipschitz_layer(fc)(
                            in_size, idim_out, ks[0], 1, ks[0] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                            domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
                        )
                    )
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
                    nnet.append(ACT_FNS[activation_fn](True))
                    for i, k in enumerate(ks[1:-1]):
                        nnet.append(
                            _lipschitz_layer(fc)(
                                idim, idim_out, k, 1, k // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                                domain=_domains[i + 1], codomain=_codomains[i + 1], atol=sn_atol, rtol=sn_rtol
                            )
                        )
                        if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
                        nnet.append(ACT_FNS[activation_fn](True))
                    if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
                    nnet.append(
                        _lipschitz_layer(fc)(
                            idim, initial_size[0], ks[-1], 1, ks[-1] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                            domain=_domains[-1], codomain=_codomains[-1], atol=sn_atol, rtol=sn_rtol
                        )
                    )
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                    return layers.iResBlock(
                        nn.Sequential(*nnet),
                        n_power_series=n_power_series,
                        n_dist=n_dist,
                        n_samples=n_samples,
                        n_exact_terms=n_exact_terms,
                        neumann_grad=neumann_grad,
                        grad_in_forward=grad_in_forward,
                    )
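A minimal sketch of calling the extended builder with the dense path enabled; the shape and hyperparameter values below are illustrative assumptions, and the remaining keyword defaults still come from the enclosing scope:

# Hypothetical call: a DenseNet-style invertible block for a 64x8x8 input,
# three dense layers of growth 32, with learnable concatenation weights.
dense_block = _resblock(
    (64, 8, 8), fc=False, idim=128,
    densenet=True, densenet_depth=3, densenet_growth=32,
    learnable_concat=True, lip_coeff=0.98,
)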
Example #3
if __name__ == '__main__':
    activation_fn = ACTIVATION_FNS[args.act]

    if args.arch == 'iresnet':
        dims = [2] + list(map(int, args.dims.split('-'))) + [2]

        blocks = []
        if args.actnorm: blocks.append(layers.ActNorm1d(2))

        for _ in range(args.nblocks):
            blocks.append(
                layers.iResBlock(
                    build_nnet(dims, activation_fn),
                    n_dist=args.n_dist,
                    n_power_series=args.n_power_series,
                    exact_trace=args.exact_trace,
                    brute_force=args.brute_force,
                    n_samples=args.n_samples,
                    neumann_grad=False,
                    grad_in_forward=False,
                )
            )
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)
    elif args.arch == 'realnvp':
        blocks = []
        for _ in range(args.nblocks):
            blocks.append(layers.CouplingLayer(2, swap=False))
            blocks.append(layers.CouplingLayer(2, swap=True))
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)
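The fragment reads several attributes from a parsed `args` object; a minimal stand-in with exactly the attributes it uses might look like this (attribute names are taken from the code above, the values are illustrative assumptions):

from types import SimpleNamespace

# Hypothetical replacement for the argparse result used by the fragment.
args = SimpleNamespace(
    act='swish', arch='iresnet', dims='128-128-128', nblocks=10,
    actnorm=True, batchnorm=False, n_dist='geometric',
    n_power_series=None, exact_trace=False, brute_force=False, n_samples=1,
)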
Example #4
def get_iresblock(initial_size, fc, idim, first_resblock=False, n_power_series=5, kernels='3-1-3',
                  batchnorm=True, preact=False, activation_fn='relu', n_lipschitz_iters=2000, coeff=0.97,
                  sn_atol=1e-3, sn_rtol=1e-3, dropout=True, n_dist='geometric',
                  n_samples=1, n_exact_terms=0, neumann_grad=True, grad_in_forward=False):
    if fc:
        return layers.iResBlock(
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
            ),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )


    else:
        ks = list(map(int, kernels.split('-')))  # kernel sizes, [3, 1, 3] by default
        nnet = []
        '''
        architecture:
        batchnorm
        conv1d
        batchnorm
        conv1d
        batchnorm
        dropout
        conv1d
        batchnorm
        '''
        if not first_resblock and preact:
            if batchnorm: nnet.append(layers.MovingBatchNormconv1d(initial_size[0]))
            nnet.append(ACT_FNS[activation_fn](False))
        nnet.append(
            _lipschitz_layer(fc)(
                initial_size[0], idim, ks[0], 1, ks[0] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                atol=sn_atol, rtol=sn_rtol
            )
        )
        if batchnorm: nnet.append(layers.MovingBatchNormconv1d(idim))
        nnet.append(ACT_FNS[activation_fn](True))
        for i, k in enumerate(ks[1:-1]):
            nnet.append(
                _lipschitz_layer(fc)(
                    idim, idim, k, 1, k // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                    atol=sn_atol, rtol=sn_rtol
                )
            )
            if batchnorm: nnet.append(layers.MovingBatchNormconv1d(idim))
            nnet.append(ACT_FNS[activation_fn](True))
        if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
        nnet.append(
            _lipschitz_layer(fc)(
                idim, initial_size[0], ks[-1], 1, ks[-1] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                atol=sn_atol, rtol=sn_rtol
            )
        )
        if batchnorm: nnet.append(layers.MovingBatchNormconv1d(initial_size[0]))
        return layers.iResBlock(
            nn.Sequential(*nnet),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
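A minimal sketch of calling this helper for a 1D feature map; the input shape, kernel string and dropout rate are illustrative assumptions, and the forward call assumes iResBlock returns only the transformed tensor when no log-density term is passed:

import torch

# Hypothetical call: one invertible residual block for a 64-channel,
# length-100 sequence, using the default 3-1-3 kernel stack.
block = get_iresblock(
    initial_size=(64, 100), fc=False, idim=128,
    kernels='3-1-3', batchnorm=True, activation_fn='relu',
    coeff=0.97, dropout=0.1,
)
out = block(torch.randn(8, 64, 100))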