Example #1: ExitBijector.__init__ (from TensorFlow Probability's Glow bijector)
  def __init__(self,
               input_shape,
               blockwise_splits,
               coupling_bijector_fn=None):
    """Creates the exit bijector.

    Args:
      input_shape: A tuple specifying the input shape to the exit bijector.
        Used in constructing the network.
      blockwise_splits: A list of three integers, `[nleave, ngrab, npass]`:
        the number of channels being left in the model, the number exiting
        the model here, and the number bypassing the exit bijector
        altogether.
      coupling_bijector_fn: A function which takes the argument `input_shape`
        and returns a callable neural network (e.g. a keras Sequential). The
        network should either return a tensor with the same event shape as
        `input_shape` (this will employ additive coupling), a tensor with the
        same height and width as `input_shape` but twice the number of channels
        (this will employ affine coupling), or a bijector which takes in a
        tensor with event shape `input_shape`, and returns a tensor with shape
        `input_shape`.
    """

    nleave, ngrab, npass = blockwise_splits

    new_input_shape = input_shape[:-1] + (nleave,)
    target_output_shape = input_shape[:-1] + (ngrab,)

    # If nleave or ngrab is 0, just use an identity for everything.
    if nleave == 0 or ngrab == 0:
      exit_layer = None
      exit_bijector_fn = None

      self.exit_layer = exit_layer
      shift_distribution = identity.Identity()

    else:
      exit_layer = coupling_bijector_fn(new_input_shape,
                                        output_chan=ngrab)
      # The network output parameterizes a shift (and, via `scale_fn=tf.exp`,
      # a log-scale) for the exiting channels.
      exit_bijector_fn = self.make_bijector_fn(
          exit_layer,
          target_shape=target_output_shape,
          scale_fn=tf.exp)
      self.exit_layer = exit_layer  # For variable tracking.
      # RealNVP conditions on the first `nleave` channels and transforms the
      # remaining `ngrab` channels.
      shift_distribution = real_nvp.RealNVP(
          num_masked=nleave,
          bijector_fn=exit_bijector_fn)

    # Apply `shift_distribution` blockwise to the first nleave + ngrab
    # channels, and the identity to the remaining npass channels.
    super(ExitBijector, self).__init__(
        [shift_distribution, identity.Identity()], [nleave + ngrab, npass])
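
The docstring above fixes the contract for `coupling_bijector_fn`. As a
minimal sketch (the function name and layer choices are illustrative, not the
library defaults), a compatible network for the exit bijector takes the
`nleave` retained channels and emits `2 * output_chan` channels, which selects
the affine-coupling branch:

import tensorflow as tf

def exit_coupling_fn(input_shape, output_chan):
  # Zero-initialized output makes the coupling start out as the identity,
  # a common trick for stabilizing Glow training.
  return tf.keras.Sequential([
      tf.keras.Input(shape=input_shape),
      tf.keras.layers.Conv2D(
          2 * output_chan, 3, padding='same',
          kernel_initializer='zeros', bias_initializer='zeros'),
  ])

Returning `output_chan` channels instead would select additive coupling
(shift only), and per the docstring the callable may also return a bijector
over events of shape `input_shape`.
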
Example #2: GlowBlock.__init__ (from TensorFlow Probability's Glow bijector)
    def __init__(self, input_shape, num_steps, coupling_bijector_fn,
                 use_actnorm, seedstream):
        parameters = dict(locals())
        rnvp_block = [identity.Identity()]
        this_nchan = input_shape[-1]

        for _ in range(num_steps):

            this_layer_input_shape = input_shape[:-1] + (input_shape[-1] // 2,)
            this_layer = coupling_bijector_fn(this_layer_input_shape)
            bijector_fn = self.make_bijector_fn(this_layer)

            # For each step in the block, we do (optional) actnorm, followed
            # by an invertible 1x1 convolution, then affine coupling.
            this_rnvp = invert.Invert(
                real_nvp.RealNVP(this_nchan // 2, bijector_fn=bijector_fn))

            # Attach the coupling layer to the realNVP bijector for variable
            # tracking.
            this_rnvp.coupling_bijector_layer = this_layer
            rnvp_block.append(this_rnvp)

            rnvp_block.append(
                invert.Invert(
                    OneByOneConv(this_nchan,
                                 seed=seedstream(),
                                 dtype=dtype_util.common_dtype(
                                     this_rnvp.variables,
                                     dtype_hint=tf.float32))))

            if use_actnorm:
                rnvp_block.append(
                    ActivationNormalization(this_nchan,
                                            dtype=dtype_util.common_dtype(
                                                this_rnvp.variables,
                                                dtype_hint=tf.float32)))

        # Note that we reverse the list since Chain applies bijectors in reverse
        # order.
        super(GlowBlock, self).__init__(chain.Chain(rnvp_block[::-1]),
                                        parameters=parameters,
                                        name='glow_block')
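
Because `chain.Chain` composes right-to-left, like function composition
(`Chain([f, g]).forward(x) == f.forward(g.forward(x))`), reversing
`rnvp_block` preserves the order in which the steps were appended. A small
self-contained illustration of that convention:

import tensorflow_probability as tfp
tfb = tfp.bijectors

bij = tfb.Chain([tfb.Scale(2.), tfb.Shift(1.)])
# Shift is applied first, then Scale: 2 * (1 + 1) = 4.
print(bij.forward(1.))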