Example #1
File: rl.py Project: stephenjfox/trax
 def ActionInjector(mode):
     if inject_actions:
         if is_discrete:
             action_encoder = tl.Embedding(vocab_size, inject_actions_dim)
         else:
             action_encoder = tl.Dense(inject_actions_dim)
         encoders = tl.Parallel(
             tl.Dense(inject_actions_dim),
             action_encoder,
         )
         if multiplicative_action_injection:
             action_injector = tl.Serial(
                 tl.Fn('TanhMulGate', lambda x, a: x * jnp.tanh(a)),
                 tl.LayerNorm()  # compensate for reduced variance
             )
         else:
             action_injector = tl.Add()
         return tl.Serial(
             # Input: (body output, actions).
             encoders,
             action_injector,
             models.MLP(
                 layer_widths=(inject_actions_dim, ) *
                 inject_actions_n_layers,
                 out_activation=True,
                 flatten=False,
                 mode=mode,
             ))
     else:
         return []
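The multiplicative injection path above gates the body output with tanh of the encoded actions before normalizing. A minimal standalone sketch of just that gate, with illustrative shapes (here x and a stand in for the body output and encoded actions, which in the original come from the enclosing closure):

import numpy as np
from trax import layers as tl, shapes
from trax.fastmath import numpy as jnp

# Standalone version of the TanhMulGate + LayerNorm pair used above.
gate = tl.Serial(
    tl.Fn('TanhMulGate', lambda x, a: x * jnp.tanh(a)),
    tl.LayerNorm(),  # compensate for reduced variance
)
x = np.ones((2, 4), dtype=np.float32)       # stand-in for the body output
a = np.full((2, 4), 0.5, dtype=np.float32)  # stand-in for encoded actions
gate.init(shapes.signature((x, a)))
print(gate((x, a)).shape)  # (2, 4)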
Example #2
def ResidualZero(*layers, shortcut=None):
    """Wraps a series of layers with a ReZero-style residual connection.

  Instead of computing `(shortcut) + (output of layers)`, as in a classical
  Residual connection, ResidualZero computes
  `(shortcut) + alpha * (output of layers)`, where `alpha` is a learnable
  scalar initialized to zero.

  Args:
    *layers: One or more layers, to be applied in series.
    shortcut: If None (the usual case), the Residual layer computes the
        element-wise sum of the stack-top input with the output of the layer
        series. If specified, the `shortcut` layer applies to a copy of the
        inputs and (elementwise) adds its output to the output from the main
        layer series.

  Returns:
      A layer representing a residual connection paired with a layer series.
  """
    layers = _ensure_flat(layers)
    layer = layers[0] if len(layers) == 1 else tl.Serial(layers)
    # TODO(jaszczur): perhaps change inner Serial to Branch?
    return tl.Serial(
        tl.Branch(
            shortcut,
            tl.Serial(
                layer,
                tl.Weights(
                    lambda shape, rng: jnp.zeros(shape, dtype=jnp.float32)),
                tl.Multiply())),
        tl.Add(),  # pylint: disable=no-value-for-parameter
    )
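Because `alpha` starts at zero, a ResidualZero block is the identity right after initialization. A minimal usage sketch, assuming the ResidualZero defined above (together with trax's internal _ensure_flat helper) is in scope:

import numpy as np
from trax import layers as tl, shapes

block = ResidualZero(tl.Dense(6))
x = np.random.uniform(size=(2, 6)).astype(np.float32)
block.init(shapes.signature(x))
# alpha == 0 at init, so the wrapped layers contribute nothing yet.
np.testing.assert_allclose(block(x), x, atol=1e-6)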
Example #3
File: sparsity.py Project: piotrekp1/trax
def EinsumDense(d_input, d_output, use_bias):
    """Returns a reimplementation of Dense layer, using einsum.

  While this is equivalent to a Dense layer, it seems to be faster in decoding
  when used with bias (see decoding_timing_test.py).
  This layer can be removed once we better understand the reason for the
  difference in decoding speed.

  Args:
    d_input: Dimensionality of the input tensor.
    d_output: Dimensionality of the output tensor.
    use_bias: Whether to use bias.
  """
    layers = [
        tl.Weights(init.GlorotUniformInitializer(), [d_output, d_input]),
        tl.Fn(
            'EinsumDense',
            (
                lambda kernel, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('xd,...d->...x', kernel, embeds)))
    ]
    if use_bias:
        layers.extend([
            tl.Weights(init.RandomNormalInitializer(1e-6), [d_output]),
            tl.Add()
        ])
    return tl.Serial(layers)
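A minimal usage sketch, assuming the EinsumDense above and its imports (tl, init, jnp) are in scope; like tl.Dense(d_output), it maps the last axis from d_input to d_output:

import numpy as np
from trax import shapes

layer = EinsumDense(d_input=8, d_output=16, use_bias=True)
x = np.ones((2, 3, 8), dtype=np.float32)
layer.init(shapes.signature(x))
print(layer(x).shape)  # (2, 3, 16)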
Example #4
  def __init__(self, pre_attention, attention, post_attention):
    self.pre_attention = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(pre_attention, [], []),
    )
    assert hasattr(attention, 'forward_and_backward')
    self.attention = ApplyAttentionWrapper(attention)
    self.post_attention = tl.Parallel(post_attention, [], [])

    layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        tl.Parallel(tl.Add(), []),
    ]
    super(ReversibleAttentionHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
    self.reverse_layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        self.subtract_top,
    ]
Example #5
 def test_add_div(self):
   layer = tl.Branch(tl.Add(), DivideBy(0.5))
   xs = [np.array([1, 2, 3]),
         np.array([10, 20, 30])]
   ys = layer(xs)
   self.assertEqual(as_list(ys), [[11, 22, 33],
                                  [2, 4, 6]])
Example #6
    def test_run_reversible_same_as_default_extended(self):
        """Runs the reversible trainer, check results are the same as default."""
        inputs_batch = np.arange(8).reshape((2, 4))
        targets_batch = 2 * inputs_batch
        labeled_batch = (inputs_batch, targets_batch,
                         np.ones_like(targets_batch))
        # We want to test rng propagation too, so adding some dropout layers.
        first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup())
        rev_layers1 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)),
            tl.ReversibleSwap(),
            tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)),
            tl.ReversibleSwap()
        ]
        mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup())
        rev_layers2 = [
            tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)),
            tl.ReversibleSwap()
        ]
        loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3),
                               tl.LogSoftmax(), tl.CrossEntropyLoss())
        model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] +
                          rev_layers2 + [loss_layer])
        rng_init = fastmath.random.get_prng(12)
        model.init(labeled_batch, rng=rng_init)
        optimizer_fn = optimizers.Adam  # to test slots

        # Make 3 steps with the original trainer.
        optimizer = optimizer_fn()
        optimizer.tree_init(model.weights)
        trainer = optimizers.Trainer(model, optimizer)
        rng_step1 = fastmath.random.get_prng(7)
        rng_step2 = fastmath.random.get_prng(8)
        rng_step3 = fastmath.random.get_prng(9)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
        first_layer_weights1 = first_layer.weights
        rev_layer12_weights1 = rev_layers1[2].weights
        mid_layer_weights1 = mid_layer.weights
        rev_layer20_weights1 = rev_layers2[0].weights
        loss_layer_weights1 = loss_layer.weights

        # Now make 3 steps with reversible trainer.
        model.init(labeled_batch, rng=rng_init)
        trainer = optimizers.ReversibleSerialTrainer(
            [(first_layer.sublayers, rev_layers1),
             (mid_layer.sublayers, rev_layers2)], loss_layer, optimizer_fn)
        trainer.one_step(labeled_batch, rng_step1)
        trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
        trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

        # Check that weights end up the same.
        self._assert_all_equal(loss_layer_weights1, loss_layer.weights)
        self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights)
        self._assert_all_equal(mid_layer_weights1, mid_layer.weights)
        self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights)
        self._assert_all_equal(first_layer_weights1, first_layer.weights)
Example #7
def BERTPretrainingLoss():
    nsp_loss = [
        tl.Select([0, 2, 3], n_in=6),  # (nsp_output, nsp_targets, nsp_weights)
        tl.WeightedCategoryCrossEntropy()
    ]
    mlm_loss = [
        tl.Select([1, 4, 5], n_in=6),  # (mlm_output, mlm_targets, mlm_weights)
        tl.WeightedCategoryCrossEntropy()
    ]
    return tl.Serial(tl.Branch(nsp_loss, mlm_loss), tl.Add())
Example #8
File: reformer.py Project: syyunn/trax
    def __init__(self, residual_layers):
        self.compute_residual = tl.Serial(
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(residual_layers, [], []),
        )

        layers = [self.compute_residual, tl.Parallel(tl.Add(), [])]
        super(ReversibleHalfResidual, self).__init__(layers)

        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [self.compute_residual, self.subtract_top]
Example #9
File: sparsity.py Project: piotrekp1/trax
def MultiplicativeSparseDense(sparsity,
                              d_input,
                              d_output=None,
                              use_bias=True,
                              use_bfloat16=False):
    """Returns a replacement of Dense layer which uses less parameters.

  The layer uses number of modules equal to `sparsity`. It multiplies each
  dimension of the input tensor by a scalar specific to each dimension and each
  module separately; then it applies Dense(d_output/sparsity) to each module.
  Compared to standard dense layer, MultiplicativeSparseDense uses less
  parameters while still being able to express many interesting functions (for
  example a permutation).

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_input: Dimensionality of input tensor.
    d_output: Dimensionality of output tensor; by default equal to d_input.
    use_bias: Whether to use bias.
    use_bfloat16: Whether to use bfloat16 for weights.
  """

    d_output = d_output or d_input  # default to d_input, per the docstring
    assert d_output % sparsity == 0
    d_module = d_output // sparsity

    layers = [
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_input],
                   use_bfloat16=use_bfloat16),
        # Weight below is dense kernel, shared across heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_input, d_module],
                   use_bfloat16=use_bfloat16),
        # To save memory the per-head preprocessing and multiplying by the
        # kernel is done in the same einsum.
        tl.Fn(
            'AttentionEinsum',
            (
                lambda kernel, multiplier, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('dx,hd,...d->...hx', kernel, multiplier, embeds))),
        MergeLastTwoAxes(),
    ]
    if use_bias:
        layers.extend([
            # Weight below is bias after dense, per-head.
            tl.Weights(init.RandomNormalInitializer(1e-6), [d_output],
                       use_bfloat16=use_bfloat16),
            tl.Add(),
        ])
    return tl.Serial(layers)
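A rough parameter count under illustrative sizes (the numbers below are assumptions, not values from the source) shows where the savings come from:

# MultiplicativeSparseDense: per-module multipliers + shared kernel + bias,
# versus a standard Dense layer: full kernel + bias.
sparsity, d_input, d_output = 16, 1024, 1024
d_module = d_output // sparsity

sparse_params = (sparsity * d_input      # [sparsity, d_input] multipliers
                 + d_input * d_module    # [d_input, d_module] shared kernel
                 + d_output)             # [d_output] bias
dense_params = d_input * d_output + d_output

print(sparse_params, dense_params)  # 82944 vs 1049600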
Example #10
 def loss(id_to_mask=None, has_weights=False):
   """Cross-entropy loss as scalar compatible with Trax masking."""
   return layers.Serial(
       # Swap from (pred-obs, pred-reward, target-obs, target-reward)
       # to (pred-obs, target-obs, pred-reward, target-reward).
       layers.Parallel([], layers.Swap()),
       # Cross-entropy loss for obs, L2 loss on reward.
       layers.Parallel(layers.CrossEntropyLoss(id_to_mask, has_weights),
                       layers.L2Loss(id_to_mask, has_weights)),
       # Add both losses.
       layers.Add(),
       # Zero out in this test.
       layers.Fn(lambda x: x * 0.0),
   )
Example #11
 def loss(mask_id=None, has_weights=False):
     """Cross-entropy loss as scalar compatible with Trax masking."""
     return layers.Serial(
         # Swap from (pred-obs, pred-reward, target-obs, target-reward)
         # to (pred-obs, target-obs, pred-reward, target-reward).
         layers.Parallel([], layers.Swap()),
         # Cross-entropy loss for obs, L2 loss on reward.
         layers.Parallel(
             layers.CrossEntropyLossScalar(mask_id, has_weights),
             layers.L2LossScalar(mask_id, has_weights)),
         # Add both losses.
         layers.Add(),
         # Zero out in this test.
         layers.MulConstant(constant=0.0))
Example #12
 def loss():
     """Cross-entropy loss as scalar compatible with Trax masking."""
     ones = layers.Fn(lambda x: math.numpy.ones_like(x))  # pylint: disable=unnecessary-lambda
     return layers.Serial(
         # Swap from (pred-obs, pred-reward, target-obs, target-reward)
         # to (pred-obs, target-obs, pred-reward, target-reward).
         layers.Parallel([], layers.Swap()),
         # Duplicate target-obs and target-reward and make 1 to add weights.
         layers.Parallel([], layers.Branch([], ones)),
         layers.Parallel([], [], [], [], layers.Branch([], ones)),
         # Cross-entropy loss for obs, L2 loss on reward.
         layers.Parallel(layers.CrossEntropyLoss(), layers.L2Loss()),
         # Add both losses.
         layers.Add(),
         # Zero out in this test.
         layers.Fn(lambda x: x * 0.0),
     )
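The common pattern in the three loss examples above is: rearrange the stack, compute each loss inside a Parallel, then sum the scalars with Add. A minimal self-contained sketch of that pattern with a toy L2 loss (ToyL2 and the inputs are illustrative, not part of the source):

import numpy as np
from trax import layers as tl

def toy_l2():
  return tl.Fn('ToyL2', lambda pred, target: np.mean((pred - target) ** 2))

combined = tl.Serial(
    tl.Parallel(toy_l2(), toy_l2()),  # (pred_a, tgt_a, pred_b, tgt_b) -> (loss_a, loss_b)
    tl.Add(),                         # scalar sum of the two losses
)
pred_a, tgt_a = np.ones((2, 3)), np.zeros((2, 3))
pred_b, tgt_b = np.full((2, 3), 2.0), np.zeros((2, 3))
print(combined((pred_a, tgt_a, pred_b, tgt_b)))  # 1.0 + 4.0 = 5.0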
Example #13
  def __init__(self, residual_layers):
    self.compute_residual = tl.Serial(         # x1_or_y1, x2,           ...
        tl.Select([1, 0, 1]),                  # x2, x1_or_y1, x2,       ...
        tl.Parallel([], [], residual_layers),  # x2, x1_or_y1, residual, ...
        tl.Select([2, 1, 0]),                  # residual, x1_or_y1, x2, ...
    )

    self.n_preserve = self.compute_residual.n_out - 2
    parallel_preserve = [[]] * self.n_preserve

    layers = [
        self.compute_residual,
        tl.Parallel(tl.Add(), *parallel_preserve)
    ]
    super(ReversibleHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), *parallel_preserve)
    self.reverse_layers = [self.compute_residual, self.subtract_top]
Example #14
File: rl.py Project: elliotthwang/trax
 def ActionInjector(mode):
     if inject_actions:
         return tl.Serial(
             # Input: (body output, actions).
             tl.Parallel(
                 tl.Dense(inject_actions_dim),
                 tl.Dense(inject_actions_dim),
             ),
             tl.Add(),
             models.PureMLP(
                 layer_widths=(inject_actions_dim, ) *
                 inject_actions_n_layers,
                 out_activation=True,
                 flatten=False,
                 mode=mode,
             ))
     else:
         return []
Example #15
File: sparsity.py Project: piotrekp1/trax
def MultiplicativeModularSparseDense(sparsity, d_feature):
    """Returns a replacement of Dense layer which uses less parameters.

  The layer uses number of modules equal to `sparsity`. It is a combination of
  multiplicative dense and locally connected dense layers.

  Args:
    sparsity: The sparsity of the layer; the output vector is divided into this
        number of modules.
    d_feature: Dimensionality of input and output tensor.
  """

    assert d_feature % sparsity == 0
    d_module = d_feature // sparsity

    return tl.Serial(
        # Weight below is used for per-head preprocessing of an embedding.
        tl.Weights(init.RandomNormalInitializer(stddev=0.5),
                   shape=[sparsity, d_feature]),
        # Weight below is a kernel of multiplicative dense, shared across heads.
        tl.Weights(init.GlorotUniformInitializer(), [d_feature, d_module]),
        # Weight below is a kernel of modular dense.
        tl.Weights(
            functools.partial(init.GlorotUniformInitializer(),
                              nonreceptive_dims=[0]),
            [sparsity, d_module, d_module]),
        # To save memory the per-head preprocessing and multiplying by
        # kernels is done in a single einsum.
        tl.Fn(
            'SparseDenseEinsum',
            (
                lambda kmod, kmult, multiplier, embeds:  # pylint: disable=g-long-lambda
                jnp.einsum('hxo,dx,hd,...d->...ho', kmod, kmult, multiplier,
                           embeds))),
        MergeLastTwoAxes(),
        # Weight below is bias after dense, per-head.
        tl.Weights(init.RandomNormalInitializer(1e-6), [d_feature]),
        tl.Add(),
    )
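A minimal usage sketch, assuming the MultiplicativeModularSparseDense above and its helpers (MergeLastTwoAxes, init, jnp, functools) are in scope; the feature dimension is preserved:

import numpy as np
from trax import shapes

layer = MultiplicativeModularSparseDense(sparsity=8, d_feature=64)
x = np.ones((2, 10, 64), dtype=np.float32)
layer.init(shapes.signature(x))
print(layer(x).shape)  # (2, 10, 64)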
Example #16
File: rl.py Project: yangcaot/trax
 def ActionInjector(mode):
     if inject_actions:
         if is_discrete:
             encode_layer = tl.Parallel(
                 tl.Dense(inject_actions_dim),
                 tl.Embedding(inject_actions_dim, vocab_size=vocab_size))
         else:
             encode_layer = tl.Parallel(
                 tl.Dense(inject_actions_dim),
                 tl.Dense(inject_actions_dim),
             )
         return tl.Serial(
             # Input: (body output, actions).
             encode_layer,
             tl.Add(),
             models.PureMLP(
                 layer_widths=(inject_actions_dim, ) *
                 inject_actions_n_layers,
                 out_activation=True,
                 flatten=False,
                 mode=mode,
             ))
     else:
         return []
Example #17
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
        ):
  """BERT (default hparams are for bert-base-uncased)."""
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  word_embeddings = tl.Embedding(d_model, vocab_size)
  type_embeddings = tl.Embedding(d_model, type_vocab_size)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
  embeddings = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(
          word_embeddings,
          type_embeddings,
          [tl.PaddingMask(),
           tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)]
      ),
      tl.Add(),
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  encoder = []
  for _ in range(n_layers):
    attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                            bias=True, masked=True, mode=mode)
    feed_forward = [
        tl.Dense(d_ff),
        tl.Gelu(),
        tl.Dense(d_model)
    ]
    encoder += [
        tl.Select([0, 1, 1]),  # Save a copy of the mask
        tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]

  encoder += [tl.Select([0], n_in=2)]  # Drop the mask

  pooler = [
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  init_checkpoint = init_checkpoint if mode == 'train' else None
  bert = PretrainedBERT(
      embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

  if head is not None:
    bert = tl.Serial(bert, head())

  return bert
Example #18
 def test_default_name(self):
     layer = tl.Branch(tl.Add(), DivideBy(0.5))
     self.assertIn('Branch', str(layer))
Example #19
 def test_printing_sublayers(self):
     layer = tl.Branch(tl.Add(), tl.Add())
     expected_result = 'Branch_in2_out2[\n  Add_in2\n  Add_in2\n]'
     self.assertEqual(expected_result, str(layer))
Example #20
def LatentTransformer(input_vocab_size,
                      output_vocab_size=None,
                      d_model=512,
                      d_ff=2048,
                      n_encoder_layers=6,
                      n_decoder_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      max_len=2048,
                      mode='train',
                      ff_activation=tl.Relu,
                      axial_pos_shape=None,
                      d_axial_pos_embs=None):
    """Returns a Transformer model.

  This model expects an input pair: target, source.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(input_vocab_size,
                                           d_model,
                                           mode,
                                           dropout,
                                           dropout_shared_axes,
                                           max_len,
                                           output_vocab_size=output_vocab_size,
                                           axial_pos_shape=axial_pos_shape,
                                           d_axial_pos_embs=d_axial_pos_embs))

    encoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for i in range(n_encoder_layers)
    ]

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = [
        _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for i in range(n_decoder_layers)
    ]

    compress_seq = tl.Serial(             # Input:   tok
        tl.Branch([], tl.PaddingMask()),  #          tok mask
        encoder,                          #          vec mask
        PickFirst(),                      #          vec_f mask
        tl.Select([0], n_in=2))           #          vec_f

    latent_transition = tl.Serial(
        tl.Parallel([tl.Dense(d_model), tl.Relu()],
                    [tl.Dense(d_model), tl.Relu()]), tl.Add(),
        tl.Residual(
            tl.LayerNorm(),
            tl.Dense(d_model),
            tl.Relu(),
            tl.Dropout(rate=dropout, mode=mode),
            tl.Dense(d_model),
        ))

    pred_valid = tl.Serial(tl.Dense(2), Squeeze(1))

    embed_tgt = tl.Serial(    # Input:   tok_d
        DropLast(mode=mode),  #          stok_d
        out_encoder,          #          svec_d
    )

    decode_seq = tl.Serial(                       # Input:   vec_e  tok_d
        tl.Select([1, 0, 1]),                     #  tok_d   vec_e  tok_d
        tl.Parallel(embed_tgt, [], DropFirst()),  # svec_d   vec_e  tok_d'
        ConcatDeEntoEnDe(),                       # vec_ed  tok_d'
        # Decoder blocks with causal attention.
        decoder_blocks,                           # vec_ed  tok_d'
        tl.LayerNorm(),                           # vec_ed  tok_d'
        DropFirst(),                              #  vec_d  tok_d'
        # Map to output vocab.
        tl.Dense(output_vocab_size),              # pred_d  tok_d'
    )

    # compress_seq: n_in 1 n_out 1: add mask, encode, pick last hidden
    # latent_transition: n_in 2 n_out 1: s, a -> s_1
    # pred_valid: n_in 1 n_out 1: s_1 -> pred_v
    # decode_seq: n_in 2 n_out 2: copy target, shift right, decode, output

    return tl.Serial(
        # Stack layout (positions 0..8 after each step):
        # Input:           tok_s  tok_a  tok_s1  r      v
        tl.Select([0, 1, 2, 0, 1, 3, 4]),
        #                  tok_s  tok_a  tok_s1  tok_s  tok_a  r      v

        # Encode.
        tl.Parallel(compress_seq, compress_seq),
        #                  vec_s  vec_a  tok_s1  tok_s  tok_a  r      v
        tl.Branch(latent_transition, [], tl.Select([1], n_in=2)),
        #                  vec_s1 vec_s  vec_a   tok_s1 tok_s  tok_a  r  v
        tl.Branch(pred_valid, []),
        #                  pred_v vec_s1 vec_s   vec_a  tok_s1 tok_s  tok_a  r  v
        # Decode.
        tl.Select([1, 4, 2, 5, 3, 6, 0, 8, 7]),
        #                  vec_s1 tok_s1 vec_s   tok_s  vec_a  tok_a  pred_v v  r
        tl.Parallel(decode_seq, decode_seq, decode_seq),
        #                  pred_s1 tok_s1 pred_s tok_s  pred_a tok_a  pred_v v  r
    )
Example #21
File: hourglass.py Project: google/trax
    def create_hourglass_valley(
            rest_shorten_factors,
            rest_n_funnel_blocks,  # pylint: disable = invalid-name
            current_total_pooling):
        assert rest_shorten_factors
        assert len(rest_shorten_factors) == len(rest_n_funnel_blocks)

        current_sf = rest_shorten_factors[0]
        current_n_layers = rest_n_funnel_blocks[0]

        shortening_layer = downsampling_fn(
            current_sf,
            d_model,
            is_upsampling=False,
            d_ff=d_ff,
            n_heads=n_heads,
            dropout=dropout,
            dropout_shared_axes=dropout_shared_axes,
            mode=mode,
            ff_activation=ff_activation,
            context_bias_layer=context_bias_layer,
            location_bias_layer=location_bias_layer,
            total_pooling=current_total_pooling,
            resampling_fn=attention_downsampling_fn)

        upsampling_layer = upsampling_fn(
            current_sf,
            d_model=d_model,
            is_upsampling=True,
            d_ff=d_ff,
            n_heads=n_heads,
            dropout=dropout,
            dropout_shared_axes=dropout_shared_axes,
            mode=mode,
            ff_activation=ff_activation,
            context_bias_layer=context_bias_layer,
            location_bias_layer=location_bias_layer,
            total_pooling=current_total_pooling,
            resampling_fn=attention_upsampling_fn)

        if len(rest_shorten_factors) > 1:  # we need to go deeper again
            pre_stage_blocks = create_decoder_blocks(
                current_n_layers, current_total_pooling * current_sf,
                middle_attn_type)

            post_stage_blocks = create_decoder_blocks(
                current_n_layers, current_total_pooling * current_sf,
                middle_attn_type)

            return [
                tl.Dup(),
                tl.ShiftRight(current_sf - 1, mode=mode), shortening_layer,
                pre_stage_blocks, *create_hourglass_valley(
                    rest_shorten_factors[1:], rest_n_funnel_blocks[1:],
                    current_total_pooling * current_sf), post_stage_blocks,
                upsampling_layer,
                tl.LayerNorm(),
                tl.Add()
            ]
        else:
            blocks = create_decoder_blocks(current_n_layers,
                                           current_total_pooling * current_sf,
                                           middle_attn_type)

            return [
                tl.Dup(),
                tl.ShiftRight(current_sf - 1), shortening_layer, blocks,
                upsampling_layer,
                tl.LayerNorm(),
                tl.Add()
            ]