def create_flow_model(args):
  """Build an image flow: uniform dequantization, logit transform, a Flow++
  coupling stack with a ResNet conditioner, then a diagonal-covariance
  affine Gaussian prior."""

  def image_architecture(out_shape):
    # ResNet conditioner used inside each Flow++ coupling layer; output
    # channels track the shape requested by the flow.
    return net.ResNet(out_channel=out_shape[-1],
                      n_blocks=args.n_resnet_blocks,
                      hidden_channel=args.res_block_hidden_channel,
                      nonlinearity="relu",
                      normalization="batch_norm",
                      parameter_norm="weight_norm",
                      block_type="reverse_bottleneck",
                      zero_init=True,
                      use_bias=True,
                      dropout_rate=0.2,
                      gate=False,
                      gate_final=True,
                      squeeze_excite=False)

  return nux.sequential(
      nux.UniformDequantization(),
      nux.Scale(2**args.quantize_bits),
      nux.Logit(),
      nux.FlowPlusPlus(
          n_components=args.n_mixture_components,
          n_checkerboard_splits_before=args.n_checkerboard_splits_before,
          n_channel_splits=args.n_channel_splits,
          n_checkerboard_splits_after=args.n_checkerboard_splits_after,
          apply_transform_to_both_halves=False,
          create_network=image_architecture),
      nux.Flatten(),
      nux.AffineGaussianPriorDiagCov(output_dim=args.output_dim))
def build_architecture(architecture: Sequence[str],
                       coupling_algorithm: Callable,
                       actnorm: bool = False,
                       actnorm_axes: Sequence[int] = -1,
                       glow: bool = True,
                       one_dim: bool = False):
  """Assemble a flow from a sequence of layer tokens.

  Tokens:
    "sq"   - squeeze (matched by a trailing UnSqueeze so output shape is preserved)
    "ms"   - multiscale factorization; the REMAINDER of the spec is built
             recursively as the inner flow and iteration stops
    "chk"  - coupling layer with a checkerboard split
    "chnl" - coupling layer with a channel split

  Args:
    architecture: sequence of the string tokens above.
    coupling_algorithm: factory taking `split_kind=` and returning a coupling layer.
    actnorm: insert an ActNorm before each coupling block.
    actnorm_axes: axis (or axes) passed to ActNorm.
    glow: use an invertible mixing layer (1x1 conv, or LDU when `one_dim`)
          instead of a plain Reverse permutation.
    one_dim: input is 1-D, so use AffineLDU rather than OneByOneConv.

  Returns:
    A nux.sequential flow.

  Raises:
    ValueError: on an unrecognized token.
  """
  n_squeeze = 0
  layers = []
  # No need to materialize `list(enumerate(...))`; iterate lazily.
  for i, layer in enumerate(architecture):
    # We don't want to put anything in front of the squeeze
    if layer == "sq":
      layers.append(nux.Squeeze())
      n_squeeze += 1
      continue

    # Should we do a multiscale factorization?
    if layer == "ms":
      # Recursively build the multiscale from the rest of the spec.
      inner_flow = build_architecture(architecture=architecture[i + 1:],
                                      coupling_algorithm=coupling_algorithm,
                                      actnorm=actnorm,
                                      actnorm_axes=actnorm_axes,
                                      glow=glow,
                                      one_dim=one_dim)
      layers.append(nux.multi_scale(inner_flow))
      break

    # Actnorm. Not needed if we're using 1x1 conv because the 1x1
    # conv is initialized with weight normalization so that its outputs
    # have 0 mean and 1 stddev.
    if actnorm:
      layers.append(nux.ActNorm(axis=actnorm_axes))

    # Use a dense connection instead of reverse?
    if glow:
      if one_dim:
        layers.append(nux.AffineLDU())
      else:
        layers.append(nux.OneByOneConv())
    else:
      layers.append(nux.Reverse())

    # Create the coupling layer.
    if layer == "chk":
      alg = coupling_algorithm(split_kind="checkerboard")
    elif layer == "chnl":
      alg = coupling_algorithm(split_kind="channel")
    else:
      # Raise instead of `assert 0`: asserts are stripped under `python -O`.
      raise ValueError(f"Unrecognized architecture token: {layer!r}")
    layers.append(alg)

  # Remember to unsqueeze so that we end up with the same shaped output.
  # Throwaway name so we don't shadow the enumerate index above.
  for _ in range(n_squeeze):
    layers.append(nux.UnSqueeze())

  return nux.sequential(*layers)
def block():
  """One residual-flow step: optional ActNorm, optional 1x1 conv, then the
  residual-flow layer itself (flags come from the enclosing scope)."""
  parts = []
  if actnorm:
    parts.append(nux.ActNorm(axis=(-3, -2, -1)))
  if one_by_one_conv:
    parts.append(nux.OneByOneConv())
  parts.append(nux.ResidualFlow(create_network=create_resnet_network))
  return nux.sequential(*parts)
def block():
  """Spline coupling (condition concatenated) followed by a safe-diagonal
  LDU mixing layer."""
  spline = nux.RationalQuadraticSpline(K=8,
                                       network_kwargs=self.network_kwargs,
                                       create_network=self.create_network,
                                       use_condition=True,
                                       coupling=True,
                                       condition_method="concat")
  mixing = nux.AffineLDU(safe_diag=True)
  return nux.sequential(spline, mixing)
def block():
  """Spline coupling (condition injected via "nin") followed by a 1x1
  convolution mixing layer."""
  spline = nux.RationalQuadraticSpline(K=8,
                                       network_kwargs=self.network_kwargs,
                                       create_network=self.create_network,
                                       use_condition=True,
                                       coupling=True,
                                       condition_method="nin")
  mixing = nux.OneByOneConv()
  return nux.sequential(spline, mixing)
def default_flow(self):
  """Default flow: logit front end, two (1x1 conv + reversed mixture-logit
  coupling) stages, unit Gaussian prior."""
  def stage():
    # Fresh modules each call, mirroring the two inline constructions.
    return [nux.OneByOneConv(),
            nux.reverse_flow(
                nux.CouplingLogitsticMixtureLogit(
                    n_components=8,
                    network_kwargs=self.network_kwargs,
                    use_condition=True))]

  layers = [nux.Logit(scale=None)]
  layers.extend(stage())
  layers.extend(stage())
  layers.append(nux.UnitGaussianPrior())
  return nux.sequential(*layers)
def default_decoder(self):
  """Decoder flow that generates positive values only (softplus-inverse
  front end), with two 1x1-conv + logistic-mixture stages."""
  def mixture():
    # Fresh mixture layer per call, matching the two inline constructions.
    return nux.LogisticMixtureLogit(n_components=4,
                                    network_kwargs=self.network_kwargs,
                                    reverse=False,
                                    use_condition=True)

  return nux.sequential(nux.SoftplusInverse(),
                        nux.OneByOneConv(),
                        mixture(),
                        nux.OneByOneConv(),
                        mixture(),
                        nux.UnitGaussianPrior())
def q_ugx(self):
  """Lazily build and cache q(u|x): a reversed logistic-mixture layer
  feeding a parametrized Gaussian prior."""
  if not hasattr(self, "_qugx"):
    # Keep this simple, but a bit more complicated than p(u|z).
    self._qugx = nux.sequential(
        nux.reverse_flow(
            nux.LogisticMixtureLogit(n_components=8,
                                     with_affine_coupling=False,
                                     coupling=False)),
        nux.ParametrizedGaussianPrior(network_kwargs=self.network_kwargs,
                                      create_network=self.create_network))
  return self._qugx
def default_flow(self):
  """Three repeated spline-coupling + 1x1-conv blocks, then a parametrized
  Gaussian prior."""
  def spline_block():
    coupling = nux.RationalQuadraticSpline(K=8,
                                           network_kwargs=self.network_kwargs,
                                           create_network=self.create_network,
                                           use_condition=True,
                                           coupling=True,
                                           condition_method="nin")
    return nux.sequential(coupling, nux.OneByOneConv())

  body = nux.repeat(spline_block, n_repeats=3)
  prior = nux.ParametrizedGaussianPrior(network_kwargs=self.network_kwargs,
                                        create_network=self.create_network)
  return nux.sequential(body, prior)
def default_flow(self):
  """Three repeated spline-coupling + LDU blocks (run in reverse), then a
  parametrized Gaussian prior."""
  def spline_block():
    coupling = nux.RationalQuadraticSpline(K=8,
                                           network_kwargs=self.network_kwargs,
                                           create_network=self.create_network,
                                           use_condition=True,
                                           coupling=True,
                                           condition_method="concat")
    return nux.sequential(coupling, nux.AffineLDU(safe_diag=True))

  body = nux.repeat(spline_block, n_repeats=3)
  prior = nux.ParametrizedGaussianPrior(network_kwargs=self.network_kwargs,
                                        create_network=self.create_network)
  return nux.sequential(nux.reverse_flow(body), prior)
def create_flow_model(args):
  """Residual-flow image model: uniform dequantization, logit transform, a
  multi-scale residual-flow stack, then a 10-class GMM prior."""
  per_scale = [args.n_resflow_repeats_per_scale] * args.n_resflow_scales
  return nux.sequential(
      nux.UniformDequantization(),
      nux.Scale(2**args.quantize_bits),
      nux.Logit(),
      nux.ResidualFlowArchitecture(
          hidden_channel_size=args.res_flow_hidden_channel_size,
          actnorm=True,
          one_by_one_conv=False,
          repititions=per_scale),  # (sic) keyword name matches the callee
      nux.Flatten(),
      nux.GMMPrior(n_classes=10))
def create_fun():
  """Build a flat flow: a padded multiscale-and-channel transform with a
  ResNet conditioner, over a unit Gaussian prior.

  Note: two alternative conditioner definitions (an MLP and a plain CNN)
  were previously kept here as commented-out code; they have been removed —
  version control preserves them if needed.
  """

  def create_network(out_shape):
    # ResNet conditioner; output channels match the requested shape.
    return net.ResNet(out_channel=out_shape[-1],
                      n_blocks=3,
                      hidden_channel=3,
                      nonlinearity="relu",
                      normalization="batch_norm",
                      parameter_norm="weight_norm",
                      block_type="reverse_bottleneck",
                      squeeze_excite=False)

  return nux.sequential(
      PaddingMultiscaleAndChannel(n_squeeze=2,
                                  output_channel=1,
                                  create_network=create_network),
      nux.UnitGaussianPrior())
def ResidualFlowArchitecture(*, hidden_channel_size, actnorm, one_by_one_conv,
                             repititions):
  """Multi-scale residual-flow stack with a Squeeze between scales.

  Args:
    hidden_channel_size: hidden width of the Lipschitz-constrained conv net.
    actnorm: prepend an ActNorm to each block.
    one_by_one_conv: prepend a 1x1 conv to each block.
    repititions: (sic — name is the public keyword) an int for a single
      scale, or a sequence of per-scale block counts.
  """
  if isinstance(repititions, int):
    repititions = [repititions]

  def create_resnet_network(out_shape):
    # Spectrally-normalized conv net (max singular value < 1), as the
    # residual-flow construction requires.
    return net.ReverseBottleneckConv(
        out_channel=out_shape[-1],
        hidden_channel=hidden_channel_size,
        nonlinearity="lipswish",
        normalization=None,
        parameter_norm="differentiable_spectral_norm",
        use_bias=True,
        dropout_rate=None,
        gate=False,
        activate_last=False,
        max_singular_value=0.999,
        max_power_iters=1)

  def block():
    parts = []
    if actnorm:
      parts.append(nux.ActNorm(axis=(-3, -2, -1)))
    if one_by_one_conv:
      parts.append(nux.OneByOneConv())
    parts.append(nux.ResidualFlow(create_network=create_resnet_network))
    return nux.sequential(*parts)

  stack = []
  for scale_idx, n_blocks in enumerate(repititions):
    if scale_idx > 0:
      # New scale: squeeze before the next run of blocks.
      stack.append(nux.Squeeze())
    stack.append(nux.repeat(block, n_repeats=n_blocks))
  return nux.sequential(*stack)
def create_fun(should_repeat=True, n_repeats=2):
  """Return `n_repeats` copies of `block`, either via the library's repeat
  combinator or by explicitly instantiating each copy in a sequential."""
  if not should_repeat:
    return nux.sequential(*(block() for _ in range(n_repeats)))
  return repeat(block, n_repeats=n_repeats)
def block():
  """A dense layer followed by an elementwise shift-and-scale."""
  return nux.sequential(Dense(), ShiftScale())