Beispiel #1
0
def NoUpsampling(shorten_factor, d_model, *args, **kwargs):
    """Returns a layer that outputs zeros with axis 1 scaled by `shorten_factor`.

    The wrapped function ignores the values of its input and only reads its
    shape and dtype. The input must have at least three axes, because
    `x.shape[0]`, `x.shape[1]` and `x.shape[2]` are all indexed.

    Args:
      shorten_factor: Multiplier applied to the size of axis 1 in the output.
      d_model: Unused; accepted and discarded.
      *args: Unused; accepted and discarded.
      **kwargs: Unused; accepted and discarded.

    Returns:
      A `core.Fn` layer named 'ReturnZero'.
    """
    del d_model, args, kwargs

    def _zeros(x):
        # Same first and third axis sizes as the input; axis 1 is enlarged.
        out_shape = (x.shape[0], x.shape[1] * shorten_factor, x.shape[2])
        return jnp.zeros(out_shape, dtype=x.dtype)

    return core.Fn('ReturnZero', _zeros)
Beispiel #2
0
def LinearUpsampling(shorten_factor, d_model, *args, dropout=0.0, mode='train',
                     **kwargs):
  """Builds a Dense -> Dropout -> reshape stack that enlarges axis 1.

  The dense layer is created with `shorten_factor * d_model` units; the final
  reshape multiplies the size of axis 1 by `shorten_factor` and lets the last
  axis be inferred.

  Args:
    shorten_factor: Factor by which axis 1 grows in the final reshape.
    d_model: Multiplied by `shorten_factor` to size the dense layer.
    *args: Unused; accepted and discarded.
    dropout: Passed to `core.Dropout` as `rate`.
    mode: Passed to `core.Dropout` as `mode`.
    **kwargs: Unused; accepted and discarded.

  Returns:
    A `cb.Serial` combinator holding the three layers in order.
  """
  del args, kwargs

  def _prolong_back(x):
    # Keep axis 0, enlarge axis 1, infer whatever remains for the last axis.
    size0, size1 = x.shape[0], x.shape[1]
    return jnp.reshape(x, (size0, size1 * shorten_factor, -1))

  projection = core.Dense(shorten_factor * d_model)
  regularizer = core.Dropout(rate=dropout, mode=mode)
  prolong = core.Fn('ProlongBack', _prolong_back, n_out=1)
  return cb.Serial(projection, regularizer, prolong)
Beispiel #3
0
def LinearPooling(shorten_factor, d_model, *args, dropout=0.0, mode='train',
                  **kwargs):
  """Builds a reshape -> Dense -> Dropout stack that shrinks axis 1.

  The first layer reshapes its input so that axis 1 has size
  `x.shape[1] // shorten_factor`, with the last axis inferred; a dense layer
  with `d_model` units and a dropout layer follow.

  Args:
    shorten_factor: Divisor (integer division) for the size of axis 1.
    d_model: Number of units passed to `core.Dense`.
    *args: Unused; accepted and discarded.
    dropout: Passed to `core.Dropout` as `rate`.
    mode: Passed to `core.Dropout` as `mode`.
    **kwargs: Unused; accepted and discarded.

  Returns:
    A `cb.Serial` combinator holding the three layers in order.
  """
  del args, kwargs

  def _shorten(x):
    # Shrink axis 1; the elements removed from it end up in the inferred
    # trailing axis.
    size0, size1 = x.shape[0], x.shape[1]
    return jnp.reshape(x, (size0, size1 // shorten_factor, -1))

  shorten = core.Fn('Shorten', _shorten, n_out=1)
  projection = core.Dense(d_model)
  regularizer = core.Dropout(rate=dropout, mode=mode)
  return cb.Serial(shorten, projection, regularizer)
def Rotate():  # pylint: disable=invalid-name
    """Returns a `core.Fn` layer named 'Rotate' wrapping the `rotate` function.

    `rotate` is resolved from the enclosing module scope; its definition is
    not part of this block.
    """
    rotate_layer = core.Fn('Rotate', rotate)
    return rotate_layer
Beispiel #5
0
def NaiveUpsampling(shorten_factor, d_model, *args,
                    **kwargs):  # pylint: disable = unused-argument
    """Returns a layer that repeats its input along axis 1.

    Args:
      shorten_factor: Passed to `jnp.repeat` as the repeat count.
      d_model: Unused.
      *args: Unused.
      **kwargs: Unused.

    Returns:
      A `core.Fn` layer named 'Repeat'.
    """

    def _repeat(x):
        # Each entry along axis 1 is duplicated `shorten_factor` times.
        return jnp.repeat(x, shorten_factor, axis=1)

    return core.Fn('Repeat', _repeat)