def __init__(self, pre_attention, attention, post_attention):
  """Construct a reversible half-residual around an attention layer.

  Forward computation is `pre_attention -> attention -> post_attention`
  applied to a copy of the second stack element, then added onto the first
  (the final `tl.Add()`); `reverse_layers` ends in `SubtractTop` so the
  input can be reconstructed from the output during the backward pass.

  Args:
    pre_attention: layer run on the duplicated second stack element before
      attention (wrapped in `tl.Parallel` so other stack slots pass through).
    attention: attention layer; must provide a `forward_and_backward`
      method (checked below) — presumably used for a custom,
      memory-efficient backward pass. TODO(review): confirm against
      `ApplyAttentionWrapper`.
    post_attention: layer run on the attention output.
  """
  self.pre_attention = tl.Serial(
      # (x1_or_y1, x2) -> (x2, x1_or_y1, x2)
      tl.Parallel([], tl.Dup()),
      tl.Swap(),
      tl.Parallel(pre_attention, [], []),
  )
  # Explicit raise instead of `assert`: asserts are stripped when Python
  # runs with -O, which would silently skip this contract check. Keep
  # AssertionError so any existing callers catching it still work.
  if not hasattr(attention, 'forward_and_backward'):
    raise AssertionError(
        'attention must have a forward_and_backward method')
  self.attention = ApplyAttentionWrapper(attention)
  self.post_attention = tl.Parallel(post_attention, [], [])

  layers = [
      self.pre_attention,
      self.attention,
      self.post_attention,
      tl.Parallel(tl.Add(), []),
  ]
  super(ReversibleAttentionHalfResidual, self).__init__(layers)

  # Inverse of the final Add; together with the forward sublayers this
  # recomputes the block's input from its output.
  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = [
      self.pre_attention,
      self.attention,
      self.post_attention,
      self.subtract_top,
  ]
def __init__(self, residual_layers):
  """Half-residual over a 2-element stack: y1 = x1 + residual(x2)."""
  # Duplicate the second stack element, bring one copy to the top, and run
  # the residual branch on it: (x1_or_y1, x2) -> (x2, x1_or_y1, x2).
  self.compute_residual = tl.Serial(
      tl.Parallel([], tl.Dup()),
      tl.Swap(),
      tl.Parallel(residual_layers, [], []),
  )
  combine = tl.Parallel(tl.Add(), [])
  super(ReversibleHalfResidual, self).__init__(
      [self.compute_residual, combine])
  # Re-running compute_residual and then subtracting it from the top undoes
  # the forward Add, letting the backward pass reconstruct the inputs.
  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = [self.compute_residual, self.subtract_top]
def __init__(self, residual_layers):
  """Half-residual whose residual branch reads the second stack element.

  Unlike the fixed 2-element variant, this version preserves any extra
  stack elements below the top two.
  """
  # Stack transformation, top-first:
  #   (x1_or_y1, x2, ...)  --Select([1,0,1])-->   (x2, x1_or_y1, x2, ...)
  #   residual_layers on third slot          -->  (x2, x1_or_y1, residual, ...)
  #   --Select([2,1,0])-->                        (residual, x1_or_y1, x2, ...)
  self.compute_residual = tl.Serial(
      tl.Select([1, 0, 1]),
      tl.Parallel([], [], residual_layers),
      tl.Select([2, 1, 0]),
  )
  # Everything below the top two stack entries must pass through untouched;
  # each `[]` in a Parallel is an identity slot.
  self.n_preserve = self.compute_residual.n_out - 2
  passthrough = [[]] * self.n_preserve
  super(ReversibleHalfResidual, self).__init__([
      self.compute_residual,
      tl.Parallel(tl.Add(), *passthrough),
  ])
  # compute_residual followed by subtract_top inverts the forward Add, so
  # the reversible backward pass can reconstruct inputs from outputs.
  self.subtract_top = tl.Parallel(tl.SubtractTop(), *passthrough)
  self.reverse_layers = [self.compute_residual, self.subtract_top]