Esempio n. 1
0
def LocallyConvDense(n_modules, n_units, kernel_size=1, length_kernel_size=1):
    """Builds a weight-sharing convolutional approximation of a Dense layer.

    The last axis of the input is split into `n_modules` modules, the result
    is padded, one convolution is run over the modules, and the last two axes
    are merged back together. Unlike a locally connected layer, the same
    convolution weights are shared by all modules.

    Args:
      n_modules: Number of modules (pixels) the last axis is split into. When
          this is 1, a plain `tl.Dense(n_units)` is returned and the kernel
          arguments are not examined.
      n_units: Number of outputs (filters) produced for each module.
      kernel_size: Kernel size along the module axis; must be odd.
      length_kernel_size: Kernel size along the axis before the module axis
          (often the sentence length in sequence models). When > 1 that axis
          is padded only on the left by `length_kernel_size - 1`, so the
          convolution is causal along it.

    Returns:
      A `tl.Dense` layer if `n_modules == 1`, otherwise a `tl.Serial` of
      split, pad, convolution and merge layers.

    Raises:
      ValueError: If `n_modules != 1` and `kernel_size` is not odd.
    """
    # A single module means there is nothing to split: use an ordinary Dense.
    if n_modules == 1:
        return tl.Dense(n_units)

    kernel_is_odd = kernel_size % 2 == 1
    if not kernel_is_odd:
        raise ValueError('Currently we only handle odd kernel sizes.')

    # Symmetric padding on the module axis; left-only padding on the axis
    # before it. The first and last axes are left unpadded.
    side = (kernel_size - 1) // 2
    causal_left = length_kernel_size - 1
    pad_widths = [
        [0, 0],
        [causal_left, 0],
        [side, side],
        [0, 0],
    ]

    def _pad(x):
        return jnp.pad(x, pad_width=pad_widths)

    split_layer = tl.SplitLastAxis(n_modules)
    pad_layer = tl.Fn('Pad', _pad)
    conv_layer = tl.Conv(n_units, kernel_size=(length_kernel_size, kernel_size))
    merge_layer = tl.MergeLastTwoAxes()
    return tl.Serial(split_layer, pad_layer, conv_layer, merge_layer)
Esempio n. 2
0
def LocallyConnectedDense(
        n_modules,
        n_units,
        kernel_size=1,  # pylint: disable=invalid-name
        kernel_initializer=init.GlorotUniformInitializer(),
        bias_initializer=init.RandomNormalInitializer(1e-6),
        use_bias=True):
    """Builds a LocallyConnected1d-based approximation of a Dense layer.

    The last axis of the input is split into `n_modules` modules,
    `LocallyConnected1d` (a grouped convolution) is applied across them, and
    the last two axes are merged back together. This is a locally-sensitive
    approximation of a Dense layer whose parameter count is smaller by a
    factor of `n_modules / kernel_size`.

    Args:
      n_modules: Number of modules (pixels) the last axis is split into. When
          this is 1, a plain `tl.Dense` is returned instead.
      n_units: Number of outputs (filters) produced for each module.
      kernel_size: Size of the kernel handed to `LocallyConnected1d`.
      kernel_initializer: Function that creates the (random) initial
          connection weights `W` for the layer.
      bias_initializer: Function that creates the (random) initial bias
          weights `b` for the layer.
      use_bias: If `True`, compute an affine map `y = Wx + b`; otherwise
          compute a linear map `y = Wx`.

    Returns:
      A `tl.Dense` layer if `n_modules == 1`, otherwise a `tl.Serial` of
      split, `LocallyConnected1d` and merge layers.
    """
    # Both branches forward the same initializer / bias configuration.
    shared_kwargs = dict(
        kernel_initializer=kernel_initializer,
        bias_initializer=bias_initializer,
        use_bias=use_bias,
    )

    if n_modules == 1:
        return tl.Dense(n_units, **shared_kwargs)

    split_layer = tl.SplitLastAxis(n_modules)
    local_layer = tl.LocallyConnected1d(
        n_units, kernel_size, padding='WRAP', **shared_kwargs)
    merge_layer = tl.MergeLastTwoAxes()
    return tl.Serial(split_layer, local_layer, merge_layer)
Esempio n. 3
0
def LocallyConvDense(n_modules, n_units, kernel_size=1):
    """Builds a weight-sharing convolutional approximation of a Dense layer.

    The last axis of the input is split into `n_modules` modules, one
    convolution is run over those modules, and the last two axes are merged
    back together. It resembles LocallyConnectedDense, except that the
    convolution weights are shared across modules.

    Args:
      n_modules: Number of modules (pixels) the last axis is split into. When
          this is 1, a plain `tl.Dense(n_units)` is returned instead.
      n_units: Number of outputs (filters) produced for each module.
      kernel_size: Kernel size along the module axis; the kernel passed to
          `tl.Conv` is `(1, kernel_size)`.

    Returns:
      A `tl.Dense` layer if `n_modules == 1`, otherwise a `tl.Serial` of
      split, convolution and merge layers.
    """
    # A single module means there is nothing to split: use an ordinary Dense.
    if n_modules == 1:
        return tl.Dense(n_units)

    split_layer = tl.SplitLastAxis(n_modules)
    conv_layer = tl.Conv(n_units, kernel_size=(1, kernel_size), padding='SAME')
    merge_layer = tl.MergeLastTwoAxes()
    return tl.Serial(split_layer, conv_layer, merge_layer)