Example #1
0
    def __init__(self, pre_attention, attention, post_attention):
        """Wires the forward layer stack and its reversal counterpart.

        Args:
          pre_attention: layer run on the residual branch before attention.
          attention: attention layer; must expose `forward_and_backward`.
          post_attention: layer run on the residual branch after attention.
        """
        # Residual branch setup: duplicate x2 and rotate it to the top, i.e.
        # (x1_or_y1, x2) -> (x2, x1_or_y1, x2), then apply pre_attention to
        # the first stream only.
        self.pre_attention = tl.Serial([
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(pre_attention, [], []),
        ])
        assert hasattr(attention, 'forward_and_backward')
        self.attention = ApplyAttentionWrapper(attention)
        self.post_attention = tl.Parallel(post_attention, [], [])

        stack = [
            self.pre_attention,
            self.attention,
            self.post_attention,
            tl.Parallel(tl.Add(), []),
        ]
        super(ReversibleAttentionHalfResidual, self).__init__(stack)

        # The reverse pass reuses the first three layers; only the trailing
        # Add is replaced by its inverse, SubtractTop.
        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = stack[:3] + [self.subtract_top]
Example #2
0
  def __init__(self, residual_layers):
    """Builds the half-residual: residual computation followed by an Add.

    Args:
      residual_layers: layer(s) applied to the first input stream to
        produce the residual that gets added to it.
    """
    # TODO(jonni): Rewrite without using Select.
    # Rearrange (x1_or_y1, x2) into (x2, x1_or_y1, x2) so residual_layers
    # sees x2 while the original pair is preserved underneath.
    reorder = tl.Select(
        inputs=('x1_or_y1', 'x2'), output=('x2', 'x1_or_y1', 'x2'))
    self.compute_residual = tl.Serial([
        reorder,
        tl.Parallel(residual_layers, [], []),
    ])

    super(ReversibleHalfResidual, self).__init__(
        [self.compute_residual, tl.Add()])

    # Reversal runs the same residual computation but undoes the Add.
    self.subtract_top = tl.SubtractTop()
    self.reverse_layers = [self.compute_residual, self.subtract_top]
Example #3
0
    def __init__(self, residual_layers):
        """Builds the half-residual: residual computation followed by an Add.

        Args:
          residual_layers: layer(s) applied to the first input stream to
            produce the residual added back to it.
        """
        # Rearrange (x1_or_y1, x2) into (x2, x1_or_y1, x2): duplicate x2,
        # rotate it to the top, then run residual_layers on that copy only.
        self.compute_residual = tl.Serial([
            tl.Parallel([], tl.Dup()),
            tl.Swap(),
            tl.Parallel(residual_layers, [], []),
        ])

        forward_stack = [self.compute_residual, tl.Parallel(tl.Add(), [])]
        super(ReversibleHalfResidual, self).__init__(forward_stack)

        # Reversal repeats the residual computation and undoes the Add.
        self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
        self.reverse_layers = [self.compute_residual, self.subtract_top]