def __init__(self, pre_attention, attention, post_attention):
  self.pre_attention = tl.Serial([
      # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
      tl.Parallel([], tl.Dup()),
      tl.Swap(),
      tl.Parallel(pre_attention, [], []),
  ])
  assert hasattr(attention, 'forward_and_backward')
  self.attention = ApplyAttentionWrapper(attention)
  self.post_attention = tl.Parallel(post_attention, [], [])

  layers = [
      self.pre_attention,
      self.attention,
      self.post_attention,
      tl.Parallel(tl.Add(), []),
  ]
  super(ReversibleAttentionHalfResidual, self).__init__(layers)

  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = [
      self.pre_attention,
      self.attention,
      self.post_attention,
      self.subtract_top,
  ]
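# --- Sketch (not from the source): the arithmetic this half-residual layer
# implements. Assuming the whole attention branch is some deterministic
# function f, the forward pass maps (x1, x2) -> (x1 + f(x2), x2); the reverse
# pass recomputes f(x2) from the unchanged x2 and subtracts it, which is why
# no intermediate activations need to be stored. The helper names below are
# illustrative only, not Trax API.

import numpy as np

def half_residual_forward(x1, x2, f):
  """(x1, x2) -> (y1, y2) with y1 = x1 + f(x2) and y2 = x2 (pass-through)."""
  return x1 + f(x2), x2

def half_residual_reverse(y1, y2, f):
  """Exactly recovers (x1, x2): x2 = y2, then x1 = y1 - f(x2)."""
  return y1 - f(y2), y2

f = lambda x: np.tanh(x)  # stand-in for pre_attention -> attention -> post_attention
x1, x2 = np.random.randn(4), np.random.randn(4)
y1, y2 = half_residual_forward(x1, x2, f)
assert np.allclose(half_residual_reverse(y1, y2, f), (x1, x2))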
def __init__(self, residual_layers):
  self.compute_residual = tl.Serial([
      # TODO(jonni): Rewrite without using Select.
      tl.Select(inputs=('x1_or_y1', 'x2'), output=('x2', 'x1_or_y1', 'x2')),
      tl.Parallel(residual_layers, [], []),
  ])

  layers = [self.compute_residual, tl.Add()]
  super(ReversibleHalfResidual, self).__init__(layers)

  self.subtract_top = tl.SubtractTop()
  self.reverse_layers = [self.compute_residual, self.subtract_top]
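# --- Sketch (not from the source): what tl.Select does to the data stack in
# the step above. Modeling the stack as a plain tuple, Select names the top
# two items and re-emits them in a new order, duplicating x2 so one copy
# feeds the residual branch while the original passes through unchanged.
# `select` below is a hypothetical stand-in, not the Trax implementation.

def select(stack, inputs, output):
  named = dict(zip(inputs, stack))      # bind names to the top stack items
  renamed = tuple(named[name] for name in output)
  return renamed + stack[len(inputs):]  # deeper stack items are untouched

assert select(('x1', 'x2'),
              inputs=('x1_or_y1', 'x2'),
              output=('x2', 'x1_or_y1', 'x2')) == ('x2', 'x1', 'x2')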
def __init__(self, residual_layers):
  self.compute_residual = tl.Serial([
      # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
      tl.Parallel([], tl.Dup()),
      tl.Swap(),
      tl.Parallel(residual_layers, [], []),
  ])

  layers = [self.compute_residual, tl.Parallel(tl.Add(), [])]
  super(ReversibleHalfResidual, self).__init__(layers)

  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = [self.compute_residual, self.subtract_top]
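# --- Sketch (not from the source): the Dup/Swap pipeline above performs the
# same stack rewiring as the Select variant, composed from two primitive
# shuffles instead. Modeled on a plain tuple stack; helper names are
# illustrative only.

def dup_second(stack):
  # tl.Parallel([], tl.Dup()): leave the top item alone, duplicate the second.
  return (stack[0], stack[1], stack[1]) + stack[2:]

def swap(stack):
  # tl.Swap(): exchange the top two stack items.
  return (stack[1], stack[0]) + stack[2:]

# (x1, x2) -> (x1, x2, x2) -> (x2, x1, x2), ready for Parallel(residual, [], []).
assert swap(dup_second(('x1', 'x2'))) == ('x2', 'x1', 'x2')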