def _resblock(initial_size, fc, idim=idim, first_resblock=False, last_fc_block=False):
    """Build one invertible residual block.

    With ``fc=True`` the residual branch is an equivariant ``FCNet``; otherwise it
    is a stack of Lipschitz-constrained equivariant convolutions with optional
    batchnorm, activation, and dropout. Either way the branch is wrapped in
    ``layers.Equivar_iResBlock``, which receives the power-series / Neumann-series
    options for the log-determinant estimate. Closure variables (``kernels``,
    ``coeff``, ``batchnorm``, ``in_type``, ``out_type``, ...) come from the
    enclosing factory.
    """
    if fc:
        return layers.Equivar_iResBlock(
            FCNet(
                in_type,
                out_type,
                group_action_type,
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_lipschitz_layer(),
                nhidden=len(kernels.split('-')) - 1,
                coeff=coeff,
                domains=domains,
                codomains=codomains,
                n_iterations=n_lipschitz_iters,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                sn_atol=sn_atol,
                sn_rtol=sn_rtol,
                learn_p=learn_p,
                last_fc_block=last_fc_block,
            ),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
    else:
        ks = list(map(int, kernels.split('-')))
        if learn_p:
            # Learn the p-norm domains; each codomain is the next layer's domain.
            _domains = [nn.Parameter(torch.tensor(0.)) for _ in range(len(ks))]
            _codomains = _domains[1:] + [_domains[0]]
        else:
            _domains = domains
            _codomains = codomains

        nnet = []
        if not first_resblock and preact:
            if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            # nnet.append(activation_fn(in_type, inplace=True))
        nnet.append(
            _lipschitz_layer()(
                in_type, out_type, group_action_type, ks[0], 1, ks[0] // 2,
                coeff=coeff, n_iterations=n_lipschitz_iters,
                domain=_domains[0], codomain=_codomains[0], atol=sn_atol, rtol=sn_rtol
            )
        )
        # Track the field type of the most recent equivariant layer so that
        # activations and convolutions still receive a valid type when batchnorm
        # or dropout modules (which expose no ``out_type``) are appended in between.
        last_type = nnet[-1].out_type
        if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
        # nnet.append(activation_fn(True))
        nnet.append(activation_fn(last_type, inplace=True))
        for i, k in enumerate(ks[1:-1]):
            nnet.append(
                _lipschitz_layer()(
                    last_type, out_type, group_action_type, k, 1, k // 2,
                    coeff=coeff, n_iterations=n_lipschitz_iters,
                    domain=_domains[i + 1], codomain=_codomains[i + 1], atol=sn_atol, rtol=sn_rtol
                )
            )
            last_type = nnet[-1].out_type
            if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
            nnet.append(activation_fn(last_type, inplace=True))
            # nnet.append(activation_fn(True))
        if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
        nnet.append(
            _lipschitz_layer()(
                last_type, in_type, group_action_type, ks[-1], 1, ks[-1] // 2,
                coeff=coeff, n_iterations=n_lipschitz_iters,
                domain=_domains[-1], codomain=_codomains[-1], atol=sn_atol, rtol=sn_rtol
            )
        )
        if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
        return layers.Equivar_iResBlock(
            nn.Sequential(*nnet),
            n_power_series=n_power_series,
            n_dist=n_dist,
            n_samples=n_samples,
            n_exact_terms=n_exact_terms,
            neumann_grad=neumann_grad,
            grad_in_forward=grad_in_forward,
        )
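
# Hypothetical usage sketch (not from the source): the enclosing factory is
# assumed to stack these blocks roughly as follows; names such as n_blocks and
# fc_end are illustrative only.
#
#     blocks = []
#     for i in range(n_blocks):
#         blocks.append(_resblock(initial_size, fc=False, first_resblock=(i == 0)))
#     if fc_end:
#         blocks.append(_resblock(initial_size, fc=True, last_fc_block=True))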
def _resblock(initial_size, fc, idim=idim, first_resblock=False):
    """Build one coupling block.

    With ``fc=True`` the conditioner is an ``FCNet`` wrapped in a plain
    ``CouplingBlock``; otherwise a convolutional conditioner is built and
    wrapped in a ``ChannelCouplingBlock`` ('channel' mask, when ``init_layer``
    is None) or a ``MaskedCouplingBlock`` ('checkerboard' mask). Each call
    flips ``nonloc_scope.swap`` so consecutive blocks alternate mask parity.
    Closure variables (``kernels``, ``batchnorm``, ``init_layer``, ...) come
    from the enclosing factory.
    """
    if fc:
        nonloc_scope.swap = not nonloc_scope.swap
        return layers.CouplingBlock(
            initial_size[0],
            FCNet(
                input_shape=initial_size,
                idim=idim,
                lipschitz_layer=_weight_layer(True),
                nhidden=len(kernels.split('-')) - 1,
                activation_fn=activation_fn,
                preact=preact,
                dropout=dropout,
                coeff=None,
                domains=None,
                codomains=None,
                n_iterations=None,
                sn_atol=None,
                sn_rtol=None,
                learn_p=None,
                div_in=2,
            ),
            swap=nonloc_scope.swap,
        )
    else:
        ks = list(map(int, kernels.split('-')))

        if init_layer is None:
            _block = layers.ChannelCouplingBlock
            _mask_type = 'channel'
            div_in = 2
            mult_out = 1
        else:
            _block = layers.MaskedCouplingBlock
            _mask_type = 'checkerboard'
            div_in = 1
            mult_out = 2

        nonloc_scope.swap = not nonloc_scope.swap
        _mask_type += '1' if nonloc_scope.swap else '0'

        nnet = []
        if not first_resblock and preact:
            if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
            nnet.append(ACT_FNS[activation_fn](False))
        nnet.append(_weight_layer(fc)(initial_size[0] // div_in, idim, ks[0], 1, ks[0] // 2))
        if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
        nnet.append(ACT_FNS[activation_fn](True))
        for i, k in enumerate(ks[1:-1]):
            nnet.append(_weight_layer(fc)(idim, idim, k, 1, k // 2))
            if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
            nnet.append(ACT_FNS[activation_fn](True))
        if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
        nnet.append(_weight_layer(fc)(idim, initial_size[0] * mult_out, ks[-1], 1, ks[-1] // 2))
        if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
        return _block(initial_size[0], nn.Sequential(*nnet), mask_type=_mask_type)
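
# Hypothetical usage sketch (not from the source; n_blocks and fc_end are
# illustrative names). Because each call flips nonloc_scope.swap, consecutive
# coupling blocks alternate which channel half / checkerboard parity is
# transformed:
#
#     blocks = [_resblock(initial_size, fc=False, first_resblock=(i == 0))
#               for i in range(n_blocks)]
#     if fc_end:
#         blocks.append(_resblock(initial_size, fc=True))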