Example #1
0
File: celeba.py  Project: C-J-Cundy/NuX
def create_flow_model(args):
  """Build the Flow++-style image flow described by the CLI arguments.

  Args:
    args: namespace carrying the hyper-parameters referenced below
      (quantize_bits, n_resnet_blocks, n_mixture_components, ...).

  Returns:
    The composed flow, as produced by nux.sequential.
  """

  def make_coupling_net(out_shape):
    # ResNet conditioner for the Flow++ coupling layers; its output
    # channel count matches the requested shape.
    return net.ResNet(out_channel=out_shape[-1],
                      n_blocks=args.n_resnet_blocks,
                      hidden_channel=args.res_block_hidden_channel,
                      nonlinearity="relu",
                      normalization="batch_norm",
                      parameter_norm="weight_norm",
                      block_type="reverse_bottleneck",
                      zero_init=True,
                      use_bias=True,
                      dropout_rate=0.2,
                      gate=False,
                      gate_final=True,
                      squeeze_excite=False)

  # Dequantize -> rescale -> logit, then the Flow++ transform, then a
  # flattened affine-Gaussian prior.
  flow_layers = [
      nux.UniformDequantization(),
      nux.Scale(2**args.quantize_bits),
      nux.Logit(),
      nux.FlowPlusPlus(n_components=args.n_mixture_components,
                       n_checkerboard_splits_before=args.n_checkerboard_splits_before,
                       n_channel_splits=args.n_channel_splits,
                       n_checkerboard_splits_after=args.n_checkerboard_splits_after,
                       apply_transform_to_both_halves=False,
                       create_network=make_coupling_net),
      nux.Flatten(),
      nux.AffineGaussianPriorDiagCov(output_dim=args.output_dim),
  ]
  return nux.sequential(*flow_layers)
Example #2
0
def build_architecture(architecture: Sequence[str],
                       coupling_algorithm: Callable,
                       actnorm: bool = False,
                       actnorm_axes: Sequence[int] = -1,
                       glow: bool = True,
                       one_dim: bool = False):
    """Assemble a normalizing flow from a sequence of layer tags.

    Args:
      architecture: tags processed in order --
        "sq":   squeeze spatial dims into channels (undone at the end),
        "ms":   multiscale split; the remaining tags recursively build the
                inner flow and iteration stops here,
        "chk":  coupling layer with a checkerboard split,
        "chnl": coupling layer with a channel split.
        (Annotation fixed: elements are string tags, not callables.)
      coupling_algorithm: callable taking split_kind= and returning a
        coupling layer.
      actnorm: if True, insert an ActNorm before each coupling block.
      actnorm_axes: axis argument forwarded to nux.ActNorm.
        NOTE(review): the default is a bare int despite the Sequence[int]
        hint -- confirm which type nux.ActNorm expects.
      glow: if True use an invertible linear mixing layer (AffineLDU for
        vectors, OneByOneConv for images) instead of a plain Reverse.
      one_dim: inputs are flat vectors rather than images.

    Returns:
      The composed flow as nux.sequential(*layers).

    Raises:
      AssertionError: if an unknown tag is encountered.
    """
    n_squeeze = 0

    layers = []
    # enumerate() already yields pairs; no need to materialize a list.
    for i, layer in enumerate(architecture):

        # We don't want to put anything in front of the squeeze
        if layer == "sq":
            layers.append(nux.Squeeze())
            n_squeeze += 1
            continue

        # Should we do a multiscale factorization?
        if layer == "ms":
            # Recursively build the inner flow from everything after "ms";
            # the outer loop then stops.
            inner_flow = build_architecture(
                architecture=architecture[i + 1:],
                coupling_algorithm=coupling_algorithm,
                actnorm=actnorm,
                actnorm_axes=actnorm_axes,
                glow=glow,
                one_dim=one_dim)
            layers.append(nux.multi_scale(inner_flow))
            break

        # Actnorm.  Not needed if we're using 1x1 conv because the 1x1
        # conv is initialized with weight normalization so that its outputs
        # have 0 mean and 1 stddev.
        if actnorm:
            layers.append(nux.ActNorm(axis=actnorm_axes))

        # Use a dense connection instead of reverse?
        if glow:
            if one_dim:
                layers.append(nux.AffineLDU())
            else:
                layers.append(nux.OneByOneConv())
        else:
            layers.append(nux.Reverse())

        # Create the coupling layer itself.
        if layer == "chk":
            alg = coupling_algorithm(split_kind="checkerboard")
        elif layer == "chnl":
            alg = coupling_algorithm(split_kind="channel")
        else:
            assert 0, f"Unknown architecture element: {layer!r}"
        layers.append(alg)

    # Remember to unsqueeze so that we end up with the same shaped output
    for _ in range(n_squeeze):
        layers.append(nux.UnSqueeze())

    return nux.sequential(*layers)
 def block():
     """One invertible residual block: optional ActNorm, optional 1x1 conv,
     then the ResidualFlow transform."""
     # Conditionally construct the optional pieces, then drop the Nones.
     maybe_parts = [
         nux.ActNorm(axis=(-3, -2, -1)) if actnorm else None,
         nux.OneByOneConv() if one_by_one_conv else None,
         nux.ResidualFlow(create_network=create_resnet_network),
     ]
     return nux.sequential(*[p for p in maybe_parts if p is not None])
Example #4
0
File: augmented.py  Project: C-J-Cundy/NuX
 def block():
   # Rational-quadratic spline coupling (condition injected by
   # concatenation) followed by an invertible dense layer with a
   # numerically safe diagonal.
   spline = nux.RationalQuadraticSpline(K=8,
                                        network_kwargs=self.network_kwargs,
                                        create_network=self.create_network,
                                        use_condition=True,
                                        coupling=True,
                                        condition_method="concat")
   return nux.sequential(spline, nux.AffineLDU(safe_diag=True))
Example #5
0
File: augmented.py  Project: C-J-Cundy/NuX
 def block():
   # Rational-quadratic spline coupling (condition injected via "nin")
   # followed by an invertible 1x1 convolution for channel mixing.
   spline = nux.RationalQuadraticSpline(K=8,
                                        network_kwargs=self.network_kwargs,
                                        create_network=self.create_network,
                                        use_condition=True,
                                        coupling=True,
                                        condition_method="nin")
   return nux.sequential(spline, nux.OneByOneConv())
Example #6
0
 def default_flow(self):
   """Logit -> two (1x1-conv, reversed mixture-coupling) stages -> prior."""

   def mixture_coupling():
     # Logistic-mixture coupling applied on the reversed ordering.
     return nux.reverse_flow(
         nux.CouplingLogitsticMixtureLogit(n_components=8,
                                           network_kwargs=self.network_kwargs,
                                           use_condition=True))

   return nux.sequential(nux.Logit(scale=None),
                         nux.OneByOneConv(),
                         mixture_coupling(),
                         nux.OneByOneConv(),
                         mixture_coupling(),
                         nux.UnitGaussianPrior())
Example #7
0
 def default_decoder(self):
     """Default decoder flow; SoftplusInverse constrains outputs to be
     positive."""

     def mixture_layer():
         # Conditioned logistic-mixture transform, forward direction.
         return nux.LogisticMixtureLogit(n_components=4,
                                         network_kwargs=self.network_kwargs,
                                         reverse=False,
                                         use_condition=True)

     return nux.sequential(nux.SoftplusInverse(),
                           nux.OneByOneConv(),
                           mixture_layer(),
                           nux.OneByOneConv(),
                           mixture_layer(),
                           nux.UnitGaussianPrior())
Example #8
0
    def q_ugx(self):
        """Return the q(u|x) flow, building and caching it on first use."""
        if not hasattr(self, "_qugx"):
            # Keep this simple, but a bit richer than p(u|z): a reversed
            # logistic-mixture bijection, then a parametrized Gaussian prior.
            self._qugx = nux.sequential(
                nux.reverse_flow(
                    nux.LogisticMixtureLogit(n_components=8,
                                             with_affine_coupling=False,
                                             coupling=False)),
                nux.ParametrizedGaussianPrior(
                    network_kwargs=self.network_kwargs,
                    create_network=self.create_network))
        return self._qugx
Example #9
0
File: augmented.py  Project: C-J-Cundy/NuX
  def default_flow(self):
    """Three spline/1x1-conv blocks followed by a parametrized Gaussian
    prior."""

    def spline_block():
      # RQ-spline coupling conditioned via "nin", then channel mixing.
      coupling = nux.RationalQuadraticSpline(K=8,
                                             network_kwargs=self.network_kwargs,
                                             create_network=self.create_network,
                                             use_condition=True,
                                             coupling=True,
                                             condition_method="nin")
      return nux.sequential(coupling, nux.OneByOneConv())

    repeated = nux.repeat(spline_block, n_repeats=3)
    prior = nux.ParametrizedGaussianPrior(network_kwargs=self.network_kwargs,
                                          create_network=self.create_network)
    return nux.sequential(repeated, prior)
Example #10
0
File: augmented.py  Project: C-J-Cundy/NuX
  def default_flow(self):
    """Three reversed spline/LDU blocks, then a parametrized Gaussian
    prior."""

    def spline_block():
      # RQ-spline coupling conditioned by concatenation, then an
      # invertible dense layer with a numerically safe diagonal.
      coupling = nux.RationalQuadraticSpline(K=8,
                                             network_kwargs=self.network_kwargs,
                                             create_network=self.create_network,
                                             use_condition=True,
                                             coupling=True,
                                             condition_method="concat")
      return nux.sequential(coupling, nux.AffineLDU(safe_diag=True))

    repeated = nux.repeat(spline_block, n_repeats=3)

    prior = nux.ParametrizedGaussianPrior(network_kwargs=self.network_kwargs,
                                          create_network=self.create_network)
    return nux.sequential(nux.reverse_flow(repeated), prior)
Example #11
0
def create_flow_model(args):
    """Residual-flow model over quantized image data with a 10-class GMM
    prior."""
    # Dequantize -> rescale -> logit, then the multi-scale residual flow,
    # then flatten into the mixture prior.
    flow = [
        nux.UniformDequantization(),
        nux.Scale(2**args.quantize_bits),
        nux.Logit(),
        nux.ResidualFlowArchitecture(
            hidden_channel_size=args.res_flow_hidden_channel_size,
            actnorm=True,
            one_by_one_conv=False,
            # Keyword spelling matches ResidualFlowArchitecture's signature.
            repititions=[args.n_resflow_repeats_per_scale] *
            args.n_resflow_scales),
        nux.Flatten(),
        nux.GMMPrior(n_classes=10),
    ]
    return nux.sequential(*flow)
Example #12
0
    def create_fun():
        """Build a flat flow: a padded multiscale/channel transform followed
        by a unit-Gaussian prior."""

        def create_network(out_shape):
            # ResNet conditioner; output channel count matches the
            # requested shape.
            return net.ResNet(out_channel=out_shape[-1],
                              n_blocks=3,
                              hidden_channel=3,
                              nonlinearity="relu",
                              normalization="batch_norm",
                              parameter_norm="weight_norm",
                              block_type="reverse_bottleneck",
                              squeeze_excite=False)

        # NOTE: alternative MLP/CNN conditioners that previously lived here
        # as commented-out code were removed; recover them from version
        # control if needed.
        flat_flow = nux.sequential(
            PaddingMultiscaleAndChannel(n_squeeze=2,
                                        output_channel=1,
                                        create_network=create_network),
            nux.UnitGaussianPrior())
        return flat_flow
def ResidualFlowArchitecture(*, hidden_channel_size, actnorm, one_by_one_conv,
                             repititions):
    """Multi-scale invertible residual-network flow.

    Args:
      hidden_channel_size: hidden width of each residual conditioner.
      actnorm: prepend an ActNorm to every residual block.
      one_by_one_conv: prepend an invertible 1x1 conv to every block.
      repititions: int or list of ints giving blocks per scale; a Squeeze
        is inserted between consecutive scales.  (Parameter spelling kept
        as-is for backward compatibility with existing callers.)

    Returns:
      The composed flow as nux.sequential.
    """
    # Normalize a bare int into a single-scale spec.
    if isinstance(repititions, int):
        repititions = [repititions]

    def create_resnet_network(out_shape):
        # Lipschitz-constrained conditioner required by ResidualFlow:
        # lipswish nonlinearity + spectral normalization.
        return net.ReverseBottleneckConv(
            out_channel=out_shape[-1],
            hidden_channel=hidden_channel_size,
            nonlinearity="lipswish",
            normalization=None,
            parameter_norm="differentiable_spectral_norm",
            use_bias=True,
            dropout_rate=None,
            gate=False,
            activate_last=False,
            max_singular_value=0.999,
            max_power_iters=1)

    def block():
        # One block: optional ActNorm, optional 1x1 conv, residual flow.
        parts = []
        if actnorm:
            parts.append(nux.ActNorm(axis=(-3, -2, -1)))
        if one_by_one_conv:
            parts.append(nux.OneByOneConv())
        parts.append(nux.ResidualFlow(create_network=create_resnet_network))
        return nux.sequential(*parts)

    scales = []
    for scale_idx, n_blocks in enumerate(repititions):
        # Squeeze between scales, but not before the first one.
        if scale_idx > 0:
            scales.append(nux.Squeeze())

        scales.append(nux.repeat(block, n_repeats=n_blocks))

    return nux.sequential(*scales)
Example #14
0
 def create_fun(should_repeat=True, n_repeats=2):
     """Chain `n_repeats` copies of block(), either via repeat() or an
     explicit sequential composition."""
     if not should_repeat:
         return nux.sequential(*(block() for _ in range(n_repeats)))
     return repeat(block, n_repeats=n_repeats)
Example #15
0
 def block():
     """A single Dense layer followed by an elementwise shift/scale."""
     return nux.sequential(Dense(), ShiftScale())