def _resblock(initial_size, fc, idim=idim, first_resblock=False,
              last_fc_block=False):
    # Relies on closure variables from the enclosing builder (in_type, out_type,
    # group_action_type, kernels, coeff, domains, codomains, n_lipschitz_iters,
    # activation_fn, preact, dropout, batchnorm, sn_atol, sn_rtol, learn_p, ...).
    if fc:
        # Fully connected branch: wrap an equivariant FCNet in an invertible
        # residual block with the usual log-det estimator settings.
        return layers.Equivar_iResBlock(
            FCNet(
                in_type,
                out_type,
                group_action_type,
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                domains=domains,
                codomains=codomains,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
                learn_p=learn_p,
                last_fc_block=last_fc_block,
            ),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
    else:
        # Convolutional branch: stack Lipschitz-constrained layers.
        ks = list(map(int, kernels.split('-')))
        if learn_p:
            # Learn the p-norm domains/codomains of each layer.
            _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
            _codomains = _domains[1:] + [_domains[0]]
        else:
            _domains = domains
            _codomains = codomains
        nnet = []
        if not first_resblock and preact:
            if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            # nnet.append(activation_fn(in_type, inplace=True))
        # First Lipschitz-constrained layer: in_type -> out_type.
        nnet.append(
            _lipschitz_layer()(
                in_type, out_type, group_action_type,
                ks[0], 1, ks[0] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
            )
        )
        if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
        # nnet.append(activation_fn(True))
        nnet.append(activation_fn(nnet[-1].out_type, inplace=True))
        # Hidden Lipschitz-constrained layers, one per intermediate kernel size.
        for i, k in enumerate(ks[1:-1]):
            nnet.append(
                _lipschitz_layer()(
                    nnet[-1].out_type, out_type, group_action_type,
                    k, 1, k // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[i + 1], codomain=_codomains[i + 1], atol=sn_atol, rtol=sn_rtol
                )
            )
            if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
            nnet.append(activation_fn(nnet[-1].out_type, inplace=True))
            # nnet.append(activation_fn(True))
        if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
        # Last layer maps back to in_type so the residual connection is valid.
        nnet.append(
            _lipschitz_layer()(
                nnet[-1].out_type, in_type, group_action_type,
                ks[-1], 1, ks[-1] // 2, coeff=coeff, n_iterations=n_lipschitz_iters,
                domain=_domains[-1], codomain=_codomains[-1], atol=sn_atol, rtol=sn_rtol
            )
        )
        if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
        # Wrap the stack in an invertible residual block.
        return layers.Equivar_iResBlock(
            nn.Sequential(*nnet),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
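For readers without the surrounding repo, the sketch below is a minimal, self-contained analogue of the convolutional branch above, assuming plain PyTorch only: torch.nn.utils.spectral_norm stands in for the repo's _lipschitz_layer() (it only approximately bounds a convolution's operator norm and ignores the coeff scaling and equivariant field types), and the names conv_residual_branch and ResidualWrapper are illustrative, not part of the repo.

# Hypothetical sketch; not the repo's implementation.
import torch
import torch.nn as nn


def conv_residual_branch(in_channels, idim=64, kernels='3-1-3'):
    # Parse the '-'-separated kernel-size string, as _resblock does.
    ks = list(map(int, kernels.split('-')))
    net = []
    # First constrained conv: input channels -> hidden width.
    net.append(nn.utils.spectral_norm(
        nn.Conv2d(in_channels, idim, ks[0], 1, ks[0] // 2)))
    net.append(nn.ReLU(inplace=True))
    # Hidden constrained convs keep the hidden width.
    for k in ks[1:-1]:
        net.append(nn.utils.spectral_norm(
            nn.Conv2d(idim, idim, k, 1, k // 2)))
        net.append(nn.ReLU(inplace=True))
    # Last conv maps back to the input channels so x + g(x) is well defined.
    net.append(nn.utils.spectral_norm(
        nn.Conv2d(idim, in_channels, ks[-1], 1, ks[-1] // 2)))
    return nn.Sequential(*net)


class ResidualWrapper(nn.Module):
    """y = x + g(x); invertible by fixed-point iteration when Lip(g) < 1."""

    def __init__(self, g):
        super().__init__()
        self.g = g

    def forward(self, x):
        return x + self.g(x)


block = ResidualWrapper(conv_residual_branch(3))
y = block(torch.randn(2, 3, 32, 32))  # same shape out: (2, 3, 32, 32)

Stacking such blocks gives an i-ResNet-style invertible network when each residual branch is strictly contractive; in the repo this is enforced by the coeff-scaled spectral normalization and equivariant layer types rather than the plain spectral_norm used here.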
Example #2
def _resblock(initial_size, fc, idim=idim, first_resblock=False):
    # Relies on closure variables from the enclosing builder (kernels, idim,
    # activation_fn, preact, dropout, batchnorm, init_layer, nonloc_scope, ...).
    if fc:
        # Fully connected branch: a coupling block whose swap flag alternates
        # across successive blocks.
        nonloc_scope.swap = not nonloc_scope.swap
        return layers.CouplingBlock(
            initial_size[0],
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_weight_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                coeff=None,
                domains=None,
                codomains=None,
                n_iterations=None,
                sn_atol=None,
                sn_rtol=None,
                learn_p=None,
                div_in=2,
            ),
            swap=nonloc_scope.swap,
        )
    else:
        ks = list(map(int, kernels.split('-')))

        # Choose channel-wise coupling (halved input channels) when no init
        # layer is used, otherwise checkerboard-masked coupling (doubled
        # output channels).
        if init_layer is None:
            _block = layers.ChannelCouplingBlock
            _mask_type = 'channel'
            div_in = 2
            mult_out = 1
        else:
            _block = layers.MaskedCouplingBlock
            _mask_type = 'checkerboard'
            div_in = 1
            mult_out = 2

        # Alternate the swap parity across blocks and encode it in the mask type.
        nonloc_scope.swap = not nonloc_scope.swap
        _mask_type += '1' if nonloc_scope.swap else '0'

        nnet = []
        if not first_resblock and preact:
            if batchnorm:
                nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            nnet.append(ACT_FNS[activation_fn](False))
        # First conv: (possibly halved) input channels -> hidden width.
        nnet.append(
            _weight_layer(fc)(initial_size[0] // div_in, idim, ks[0],
                              1, ks[0] // 2))
        if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
        nnet.append(ACT_FNS[activation_fn](True))
        for i, k in enumerate(ks[1:-1]):
            nnet.append(_weight_layer(fc)(idim, idim, k, 1, k // 2))
            if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
            nnet.append(ACT_FNS[activation_fn](True))
        if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
        # Last conv: hidden width -> (possibly doubled) output channels.
        nnet.append(
            _weight_layer(fc)(idim, initial_size[0] * mult_out, ks[-1],
                              1, ks[-1] // 2))
        if batchnorm:
            nnet.append(layers.MovingBatchNorm2d(initial_size[0]))

        return _block(initial_size[0],
                      nn.Sequential(*nnet),
                      mask_type=_mask_type)
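As a companion to this coupling-based variant, the following is a minimal, self-contained sketch of channel-wise coupling with swap alternation, assuming plain PyTorch and a simple additive coupling; the class SimpleChannelCoupling is hypothetical and is not the repo's layers.CouplingBlock / ChannelCouplingBlock, which additionally track log-determinants and provide inverses.

# Hypothetical sketch; not the repo's implementation.
import torch
import torch.nn as nn


class SimpleChannelCoupling(nn.Module):
    def __init__(self, channels, hidden, swap=False):
        super().__init__()
        assert channels % 2 == 0
        self.swap = swap
        half = channels // 2
        # Small conv net predicting an additive shift for the other half.
        self.net = nn.Sequential(
            nn.Conv2d(half, hidden, 3, 1, 1), nn.ReLU(inplace=True),
            nn.Conv2d(hidden, half, 3, 1, 1),
        )

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        if self.swap:
            x1, x2 = x2, x1              # alternate which half is transformed
        y1, y2 = x1, x2 + self.net(x1)   # invertible: x2 = y2 - self.net(x1)
        if self.swap:
            y1, y2 = y2, y1              # put halves back in their original positions
        return torch.cat([y1, y2], dim=1)


# Alternate `swap` across stacked blocks, as the builder does via nonloc_scope.swap.
flow = nn.Sequential(*[SimpleChannelCoupling(8, 32, swap=bool(i % 2)) for i in range(4)])
y = flow(torch.randn(2, 8, 16, 16))

The repo's MaskedCouplingBlock variant applies the same idea with a spatial checkerboard mask instead of a channel split, which is what the _mask_type string above selects.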