コード例 #1
0
ファイル: dequantization.py プロジェクト: jxzhangjhu/NuX
 def default_flow(self):
   """Build the default flow.

   Layer order: nux.Logit(scale=None), then two stages of
   (nux.OneByOneConv, nux.reverse_flow(nux.CouplingLogitsticMixtureLogit)),
   and finally nux.UnitGaussianPrior. Each coupling layer uses
   n_components=8, self.network_kwargs and use_condition=True.
   """
   def reversed_coupling():
     # Constructs a new coupling layer on every call; nothing is reused
     # between the two stages.
     mixture = nux.CouplingLogitsticMixtureLogit(n_components=8,
                                                 network_kwargs=self.network_kwargs,
                                                 use_condition=True)
     return nux.reverse_flow(mixture)

   layers = [nux.Logit(scale=None)]
   for _ in range(2):
     layers.append(nux.OneByOneConv())
     layers.append(reversed_coupling())
   layers.append(nux.UnitGaussianPrior())
   return nux.sequential(*layers)
コード例 #2
0
 def default_decoder(self):
     """Build the default decoder flow.

     Layer order: nux.SoftplusInverse, then two stages of
     (nux.OneByOneConv, nux.LogisticMixtureLogit), and finally
     nux.UnitGaussianPrior. Each mixture layer uses n_components=4,
     self.network_kwargs, reverse=False and use_condition=True.
     """
     def mixture_layer():
         return nux.LogisticMixtureLogit(n_components=4,
                                         network_kwargs=self.network_kwargs,
                                         reverse=False,
                                         use_condition=True)

     # Generate positive values only.
     stages = [nux.SoftplusInverse()]
     for _ in range(2):
         stages.extend([nux.OneByOneConv(), mixture_layer()])
     stages.append(nux.UnitGaussianPrior())
     return nux.sequential(*stages)
コード例 #3
0
ファイル: image_architectures.py プロジェクト: jxzhangjhu/NuX
def build_architecture(architecture: Sequence[str],
                       coupling_algorithm: Callable,
                       actnorm: bool = False,
                       actnorm_axes: "int | Sequence[int]" = -1,
                       glow: bool = True,
                       one_dim: bool = False):
    """Assemble a flow from a compact sequence of layer codes.

    Recognized codes in ``architecture``:

    * ``"sq"``   -- append ``nux.Squeeze()``. A matching ``nux.UnSqueeze()``
      is appended at the end so the output has the same shape as the input.
    * ``"ms"``   -- recursively build the remaining codes (with the same
      options), wrap them in ``nux.multi_scale`` and stop processing.
    * ``"chk"``  -- coupling layer built with ``split_kind="checkerboard"``.
    * ``"chnl"`` -- coupling layer built with ``split_kind="channel"``.

    Each coupling layer is preceded by an optional ``nux.ActNorm`` and by
    either a learned mixing layer (``glow=True``) or ``nux.Reverse()``.

    Args:
        architecture: Sequence of layer codes (see above). Must support
            slicing, because ``"ms"`` recurses on the remainder.
        coupling_algorithm: Factory called as
            ``coupling_algorithm(split_kind=...)`` to create a coupling layer.
        actnorm: If True, insert ``nux.ActNorm(axis=actnorm_axes)`` before
            each coupling layer.
        actnorm_axes: Forwarded as ``axis`` to ``nux.ActNorm``.
        glow: If True, insert ``nux.AffineLDU`` (when ``one_dim``) or
            ``nux.OneByOneConv`` before each coupling layer; otherwise insert
            ``nux.Reverse``.
        one_dim: Chooses ``nux.AffineLDU`` over ``nux.OneByOneConv`` when
            ``glow`` is True.

    Returns:
        ``nux.sequential(*layers)`` over the assembled layers.

    Raises:
        ValueError: If ``architecture`` contains an unrecognized layer code.
    """
    # Coupling layer codes -> ``split_kind`` passed to ``coupling_algorithm``.
    split_kinds = {"chk": "checkerboard", "chnl": "channel"}

    n_squeeze = 0
    layers = []
    for i, layer in enumerate(architecture):

        # We don't want to put anything in front of the squeeze.
        if layer == "sq":
            layers.append(nux.Squeeze())
            n_squeeze += 1
            continue

        # Multiscale factorization: build the rest of the spec recursively and
        # hand it to nux.multi_scale; nothing after "ms" is handled here.
        if layer == "ms":
            inner_flow = build_architecture(
                architecture=architecture[i + 1:],
                coupling_algorithm=coupling_algorithm,
                actnorm=actnorm,
                actnorm_axes=actnorm_axes,
                glow=glow,
                one_dim=one_dim)
            layers.append(nux.multi_scale(inner_flow))
            break

        # Reject unknown codes before appending anything for them. An `assert`
        # would be stripped under `python -O`, letting a bad code fall through
        # to an undefined or stale coupling layer.
        if not isinstance(layer, str) or layer not in split_kinds:
            raise ValueError(
                f"Unknown layer code {layer!r} at position {i}; "
                "expected one of 'sq', 'ms', 'chk', 'chnl'.")

        # Actnorm.  Not needed if we're using 1x1 conv because the 1x1
        # conv is initialized with weight normalization so that its outputs
        # have 0 mean and 1 stddev.
        if actnorm:
            layers.append(nux.ActNorm(axis=actnorm_axes))

        # Use a dense connection instead of reverse?
        if glow:
            layers.append(nux.AffineLDU() if one_dim else nux.OneByOneConv())
        else:
            layers.append(nux.Reverse())

        layers.append(coupling_algorithm(split_kind=split_kinds[layer]))

    # Remember to unsqueeze so that we end up with the same shaped output.
    layers.extend(nux.UnSqueeze() for _ in range(n_squeeze))

    return nux.sequential(*layers)
コード例 #4
0
 def block():
     """Return one residual block.

     The block is an optional nux.ActNorm (when ``actnorm``), then an optional
     nux.OneByOneConv (when ``one_by_one_conv``), then a nux.ResidualFlow
     built with ``create_resnet_network``. All three names come from the
     enclosing scope.
     """
     # axis=(-3, -2, -1): presumably the last three (spatial + channel) axes
     # of an image batch -- TODO confirm against callers.
     prefix = [nux.ActNorm(axis=(-3, -2, -1))] if actnorm else []
     prefix += [nux.OneByOneConv()] if one_by_one_conv else []
     residual = nux.ResidualFlow(create_network=create_resnet_network)
     return nux.sequential(*prefix, residual)
コード例 #5
0
ファイル: augmented.py プロジェクト: C-J-Cundy/NuX
 def block():
   """Return a nux.RationalQuadraticSpline coupling layer followed by a
   nux.OneByOneConv, combined with nux.sequential.

   The spline reads ``self.network_kwargs`` and ``self.create_network``
   from the enclosing method's ``self``.
   """
   # K=8: presumably the number of spline bins -- confirm in NuX docs.
   spline = nux.RationalQuadraticSpline(K=8,
                                        network_kwargs=self.network_kwargs,
                                        create_network=self.create_network,
                                        use_condition=True,
                                        coupling=True,
                                        condition_method="nin")
   mixer = nux.OneByOneConv()
   return nux.sequential(spline, mixer)