        def _resblock(initial_size, fc, idim=idim, first_resblock=False):
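            # Builds one flow block: an FC coupling block when fc=True, otherwise
            # a conv coupling block with channel or checkerboard masking.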
            if fc:
                nonloc_scope.swap = not nonloc_scope.swap
                return layers.CouplingBlock(
                    initial_size[0],
                    FCNet(
                        input_shape=initial_size,
                        idim=idim,
                        lipschitz_layer=_weight_layer(True),
                        nhidden=len(kernels.split('-')) - 1,
                        activation_fn=activation_fn,
                        preact=preact,
                        dropout=dropout,
                        coeff=None,
                        domains=None,
                        codomains=None,
                        n_iterations=None,
                        sn_atol=None,
                        sn_rtol=None,
                        learn_p=None,
                        div_in=2,
                    ),
                    swap=nonloc_scope.swap,
                )
            else:
                ks = list(map(int, kernels.split('-')))

                if init_layer is None:
                    _block = layers.ChannelCouplingBlock
                    _mask_type = 'channel'
                    div_in = 2
                    mult_out = 1
                else:
                    _block = layers.MaskedCouplingBlock
                    _mask_type = 'checkerboard'
                    div_in = 1
                    mult_out = 2

                nonloc_scope.swap = not nonloc_scope.swap
                _mask_type += '1' if nonloc_scope.swap else '0'

                nnet = []
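                # Assemble conv -> (batchnorm) -> activation layers: optional
                # pre-activation for all but the first block, dropout before the
                # final conv, which maps back to initial_size[0] * mult_out channels.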
                if not first_resblock and preact:
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))
                    nnet.append(ACT_FNS[activation_fn](False))
                nnet.append(_weight_layer(fc)(initial_size[0] // div_in, idim, ks[0], 1, ks[0] // 2))
                if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
                nnet.append(ACT_FNS[activation_fn](True))
                for i, k in enumerate(ks[1:-1]):
                    nnet.append(_weight_layer(fc)(idim, idim, k, 1, k // 2))
                    if batchnorm: nnet.append(layers.MovingBatchNorm2d(idim))
                    nnet.append(ACT_FNS[activation_fn](True))
                if dropout: nnet.append(nn.Dropout2d(dropout, inplace=True))
                nnet.append(_weight_layer(fc)(idim, initial_size[0] * mult_out, ks[-1], 1, ks[-1] // 2))
                if batchnorm: nnet.append(layers.MovingBatchNorm2d(initial_size[0]))

                return _block(initial_size[0], nn.Sequential(*nnet), mask_type=_mask_type)
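# For reference, a minimal sketch of the affine coupling step such blocks
# compute (an assumption based on the standard RealNVP formulation; the actual
# CouplingBlock internals are not shown in this example):
import torch

def affine_coupling_forward(x, nnet, swap=False):
    # Split channels; one half conditions the transform of the other.
    x_a, x_b = x.chunk(2, dim=1)
    if swap:
        x_a, x_b = x_b, x_a
    # The coupling net outputs scale and shift parameters for x_b.
    s, t = nnet(x_a).chunk(2, dim=1)
    y_b = x_b * torch.exp(s) + t
    # Log-determinant of the Jacobian is the sum of the log-scales.
    logdet = s.flatten(1).sum(dim=1)
    return torch.cat([x_a, y_b], dim=1), logdet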
Example #2
                    n_dist=args.n_dist,
                    n_power_series=args.n_power_series,
                    exact_trace=args.exact_trace,
                    brute_force=args.brute_force,
                    n_samples=args.n_samples,
                    neumann_grad=False,
                    grad_in_forward=False,
                )
            )
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)
    elif args.arch == 'realnvp':
        blocks = []
        for _ in range(args.nblocks):
            blocks.append(layers.CouplingBlock(2, build_nnet(dims, activation_fn), swap=False))
            blocks.append(layers.CouplingBlock(2, build_nnet(dims, activation_fn), swap=True))
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    logpz_meter = utils.RunningAverageMeter(0.93)
    delta_logp_meter = utils.RunningAverageMeter(0.93)
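# build_nnet is defined elsewhere in this script; a minimal stand-in, assuming
# it builds a plain MLP over the widths in `dims` with the given activation
# (the original may add spectral normalization or other constraints):
import torch.nn as nn

def build_nnet(dims, activation_fn=nn.ReLU):
    # dims, e.g. [2, 128, 128, 2]: input dim, hidden widths, output dim.
    net = []
    for in_dim, out_dim in zip(dims[:-1], dims[1:]):
        net.append(nn.Linear(in_dim, out_dim))
        net.append(activation_fn())
    return nn.Sequential(*net[:-1])  # no activation after the last layer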
Example #3
                    n_dist=args.n_dist,
                    n_power_series=args.n_power_series,
                    exact_trace=args.exact_trace,
                    brute_force=args.brute_force,
                    n_samples=args.n_samples,
                    neumann_grad=False,
                    grad_in_forward=False,
                )
            )
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)
    elif args.arch == 'realnvp':
        blocks = []
        for _ in range(args.nblocks):
            blocks.append(layers.CouplingBlock(2, build_nnet(dims, activation_fn), swap=False))
            blocks.append(layers.CouplingBlock(2, build_nnet(dims, activation_fn), swap=True))
            if args.actnorm: blocks.append(layers.ActNorm1d(2))
            if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
        model = layers.SequentialFlow(blocks).to(device)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    logpz_meter = utils.RunningAverageMeter(0.93)
    delta_logp_meter = utils.RunningAverageMeter(0.93)
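# The meters above track exponential moving averages with momentum 0.93; a
# minimal sketch, assuming the common implementation pattern behind
# utils.RunningAverageMeter (not shown in this example):
class RunningAverageMeter:
    """Exponential moving average of a scalar metric."""

    def __init__(self, momentum=0.93):
        self.momentum = momentum
        self.val = None
        self.avg = 0.0

    def update(self, val):
        if self.val is None:
            self.avg = val  # first observation seeds the average
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val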
Example #4
                n_power_series=args.n_power_series,
                exact_trace=args.exact_trace,
                brute_force=args.brute_force,
                n_samples=args.n_samples,
                neumann_grad=False,
                grad_in_forward=False,
            ))
        if args.actnorm: blocks.append(layers.ActNorm1d(2))
        if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
    model = layers.SequentialFlow(blocks).to(device)
elif args.arch == 'realnvp':
    blocks = []
    for _ in range(args.nblocks):
        blocks.append(
            layers.CouplingBlock(2,
                                 build_nnet_affine(dims, activation_fn),
                                 swap=False))
        blocks.append(
            layers.CouplingBlock(2,
                                 build_nnet_affine(dims, activation_fn),
                                 swap=True))
        if args.actnorm: blocks.append(layers.ActNorm1d(2))
        if args.batchnorm: blocks.append(layers.MovingBatchNorm1d(2))
    model = layers.SequentialFlow(blocks).to(device)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
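# build_nnet_affine is also defined elsewhere; a minimal stand-in, assuming it
# builds an MLP whose final width is doubled so the output can be split into
# scale and shift for the affine coupling transform:
import torch.nn as nn

def build_nnet_affine(dims, activation_fn=nn.ReLU):
    widths = list(dims[:-1]) + [dims[-1] * 2]  # doubled output: (scale, shift)
    net = []
    for in_dim, out_dim in zip(widths[:-1], widths[1:]):
        net.append(nn.Linear(in_dim, out_dim))
        net.append(activation_fn())
    return nn.Sequential(*net[:-1])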


# checkerboard data
def checkerboard(batch_size):
    x1 = np.random.rand(batch_size) * 4 - 2