def NoUpsampling(shorten_factor, d_model, *args, **kwargs):
  """Builds a layer that "upsamples" by emitting an all-zeros tensor.

  The output has the input's first axis, a second axis `shorten_factor`
  times longer than the input's, and the input's third axis. It uses the
  input's dtype.

  Args:
    shorten_factor: Multiplier applied to the input's second axis.
    d_model: Unused; accepted for signature compatibility with the other
      upsampling constructors.
    *args: Unused.
    **kwargs: Unused.

  Returns:
    A `core.Fn` layer named 'ReturnZero'.
  """
  del d_model, args, kwargs

  def _return_zero(x):
    # Only axes 0..2 of x are read. Presumably x is (batch, length, depth);
    # TODO confirm against callers.
    batch, length, depth = x.shape[0], x.shape[1], x.shape[2]
    upsampled_shape = (batch, length * shorten_factor, depth)
    return jnp.zeros(upsampled_shape, dtype=x.dtype)

  return core.Fn('ReturnZero', _return_zero)
def LinearUpsampling(shorten_factor, d_model, *args, dropout=0.0, mode='train',
                     **kwargs):
  """Builds a learned upsampling layer: Dense -> Dropout -> reshape.

  A dense layer produces `shorten_factor * d_model` features per position.
  After dropout, the tensor is reshaped so that the second axis becomes
  `shorten_factor` times longer. Each position's features are spread over
  `shorten_factor` consecutive output positions.

  Args:
    shorten_factor: Factor by which the second axis is lengthened.
    d_model: Feature width per output position.
    *args: Unused.
    dropout: Dropout rate applied after the dense projection.
    mode: Mode passed to the dropout layer (e.g. 'train').
    **kwargs: Unused.

  Returns:
    A `cb.Serial` layer.
  """
  del args, kwargs

  def _prolong_back(x):
    # Move features from depth back into the length axis.
    batch, length = x.shape[0], x.shape[1]
    return jnp.reshape(x, (batch, length * shorten_factor, -1))

  layers = [
      core.Dense(shorten_factor * d_model),
      core.Dropout(rate=dropout, mode=mode),
      core.Fn('ProlongBack', _prolong_back, n_out=1),
  ]
  return cb.Serial(*layers)
def LinearPooling(shorten_factor, d_model, *args, dropout=0.0, mode='train',
                  **kwargs):
  """Builds a learned pooling layer: reshape -> Dense -> Dropout.

  Groups of `shorten_factor` consecutive positions along the second axis are
  folded into the depth axis. The result is then projected to `d_model`
  features and dropout is applied.

  Args:
    shorten_factor: Factor by which the second axis is shortened. The input's
      second axis must be divisible by it.
    d_model: Output feature width of the dense projection.
    *args: Unused.
    dropout: Dropout rate applied after the dense projection.
    mode: Mode passed to the dropout layer (e.g. 'train').
    **kwargs: Unused.

  Returns:
    A `cb.Serial` layer.

  Raises:
    ValueError: (when the layer is traced/initialized) if the input's second
      axis is not a multiple of `shorten_factor`.
  """
  del args, kwargs

  def _shorten(x):
    batch, length = x.shape[0], x.shape[1]
    # Without this check, reshape(..., -1) can silently succeed on a length
    # that is not a multiple of shorten_factor (e.g. length 5, factor 2,
    # depth 4 -> depth 10). That would mix positions across group boundaries
    # and feed a wrongly sized depth into Dense.
    if length % shorten_factor != 0:
      raise ValueError(
          f'LinearPooling: sequence length {length} is not divisible by '
          f'shorten_factor {shorten_factor}.')
    # Shorten -- move to depth.
    return jnp.reshape(x, (batch, length // shorten_factor, -1))

  return cb.Serial(
      core.Fn('Shorten', _shorten, n_out=1),
      core.Dense(d_model),
      core.Dropout(rate=dropout, mode=mode),
  )
def Rotate():  # pylint: disable=invalid-name
  """Wraps the module-level `rotate` function as a layer named 'Rotate'.

  Returns:
    A `core.Fn` layer that applies `rotate` to its input.
  """
  layer_name = 'Rotate'
  return core.Fn(layer_name, rotate)
def NaiveUpsampling(shorten_factor, d_model, *args, **kwargs):  # pylint: disable=unused-argument
  """Builds a parameter-free upsampling layer that repeats each position.

  Every element along the second axis is repeated `shorten_factor` times in
  place, using `jnp.repeat(..., axis=1)`.

  Args:
    shorten_factor: Number of times each position is repeated.
    d_model: Unused; accepted for signature compatibility.
    *args: Unused.
    **kwargs: Unused.

  Returns:
    A `core.Fn` layer named 'Repeat'.
  """

  def _repeat_positions(x):
    return jnp.repeat(x, shorten_factor, axis=1)

  return core.Fn('Repeat', _repeat_positions)