Code example #1
File: reformer.py  Project: srush/trax
  def __init__(self, pre_attention, attention, post_attention):
    self.pre_attention = tl.Serial(
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
        tl.Parallel([], tl.Dup()),
        tl.Swap(),
        tl.Parallel(pre_attention, [], []),
    )
    # The attention layer must provide its own forward_and_backward pass so
    # its activations can be recomputed during backprop instead of stored.
    assert hasattr(attention, 'forward_and_backward')
    self.attention = ApplyAttentionWrapper(attention)
    self.post_attention = tl.Parallel(post_attention, [], [])

    layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        tl.Parallel(tl.Add(), []),  # y1 = x1_or_y1 + residual; x2 passes through
    ]
    super(ReversibleAttentionHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), [])  # inverse of Add: x1 = y1 - residual
    self.reverse_layers = [
        self.pre_attention,
        self.attention,
        self.post_attention,
        self.subtract_top,
    ]
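
All three examples implement the same reversible half-residual pattern from the Reformer: the forward pass computes y1 = x1 + F(x2) while passing x2 through unchanged, so reverse_layers can reconstruct x1 by recomputing the identical residual and subtracting it (SubtractTop in place of Add), with no stored activations. A minimal NumPy sketch of that invariant, using an illustrative stand-in for the pre-attention/attention/post-attention branch rather than real trax layers:

    import numpy as np

    def forward(x1, x2, residual_fn):
        return x1 + residual_fn(x2), x2      # y1 = x1 + F(x2); x2 unchanged

    def reverse(y1, y2, residual_fn):
        return y1 - residual_fn(y2), y2      # x1 = y1 - F(y2): exact inverse

    residual_fn = lambda x: np.tanh(x)       # stand-in residual branch

    x1, x2 = np.random.randn(4), np.random.randn(4)
    y1, y2 = forward(x1, x2, residual_fn)
    r1, r2 = reverse(y1, y2, residual_fn)
    assert np.allclose(r1, x1) and np.allclose(r2, x2)
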
Code example #2
File: reformer.py  Project: syyunn/trax
    def __init__(self, residual_layers):
        self.compute_residual = tl.Serial(
            # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(residual_layers, [], []),
        )

        layers = [self.compute_residual, tl.Parallel(tl.Add(), [])]
        super(ReversibleHalfResidual, self).__init__(layers)

        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])  # x1 = y1 - residual
        self.reverse_layers = [self.compute_residual, self.subtract_top]
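
The Serial block here is pure stack plumbing: Parallel([], Dup()) copies x2, Swap() brings the copy to the top, and Parallel(residual_layers, [], []) applies the residual branch to that copy only, keeping the originals intact for the Add or SubtractTop that follows. A toy list-based model of the same dataflow (stack top first; the combinator names follow trax, but the implementation is purely illustrative):

    def compute_residual(stack, f):
        x1, x2 = stack[0], stack[1]
        stack = [x1, x2, x2]                       # Parallel([], Dup())
        stack = [stack[1], stack[0], stack[2]]     # Swap()
        return [f(stack[0]), stack[1], stack[2]]   # Parallel(f, [], [])

    f = lambda v: 10.0 * v
    s = compute_residual([1.0, 2.0], f)            # -> (residual, x1, x2)
    assert s == [20.0, 1.0, 2.0]
    y = [s[0] + s[1], s[2]]                        # forward Add: y1 = 21.0
    s = compute_residual(y, f)                     # recompute on (y1, x2)
    assert [s[1] - s[0], s[2]] == [1.0, 2.0]       # SubtractTop recovers x1
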
Code example #3
  def __init__(self, residual_layers):
    self.compute_residual = tl.Serial(         # x1_or_y1, x2,           ...
        tl.Select([1, 0, 1]),                  # x2, x1_or_y1, x2,       ...
        tl.Parallel([], [], residual_layers),  # x2, x1_or_y1, residual, ...
        tl.Select([2, 1, 0]),                  # residual, x1_or_y1, x2, ...
    )

    # Any outputs beyond the (residual, x1_or_y1) pair, e.g. x2 and an
    # attention mask, must pass through the Add/SubtractTop step untouched.
    self.n_preserve = self.compute_residual.n_out - 2
    parallel_preserve = [[]] * self.n_preserve

    layers = [
        self.compute_residual,
        tl.Parallel(tl.Add(), *parallel_preserve)
    ]
    super(ReversibleHalfResidual, self).__init__(layers)

    self.subtract_top = tl.Parallel(tl.SubtractTop(), *parallel_preserve)
    self.reverse_layers = [self.compute_residual, self.subtract_top]
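
This later variant replaces the Dup/Swap pair with tl.Select([1, 0, 1]), which duplicates and reorders in one step, and it generalizes to residual branches that consume extra stack items such as an attention mask: n_preserve counts the outputs beyond the (residual, x1_or_y1) pair, and the pass-throughs in parallel_preserve thread them past Add and SubtractTop. A round-trip sketch with one preserved extra item (again a toy stack model, not trax internals):

    def compute_residual(stack, f):
        # Select([1, 0, 1]) -> Parallel([], [], f) -> Select([2, 1, 0])
        x1, x2, rest = stack[0], stack[1], stack[2:]
        return [f(x2), x1, x2] + rest              # (residual, x1, x2, ...)

    def forward(stack, f):
        s = compute_residual(stack, f)
        return [s[0] + s[1]] + s[2:]               # Parallel(Add(), *preserve)

    def reverse(stack, f):
        s = compute_residual(stack, f)
        return [s[1] - s[0]] + s[2:]               # Parallel(SubtractTop(), *preserve)

    f = lambda v: v * v
    ys = forward([3.0, 2.0, 'mask'], f)            # -> [7.0, 2.0, 'mask']
    assert reverse(ys, f) == [3.0, 2.0, 'mask']
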