def connect_xz(self, x):
        """Wire the x -> z direction of the affine coupling into the Keras graph.

        The first two entries of ``x`` go through two affine coupling stages::

            z2 = x2 * exp(S1(x1)) + T1(x1)
            z1 = x1 * exp(S2(z2)) + T2(z2)

        Intermediate tensors are cached on ``self`` (``input_x1``,
        ``input_x2``, ``Sxy_layer``, ``Txy_layer``, ``output_z2``,
        ``Syz_layer``, ``Tyz_layer``, ``output_z1``). The log-Jacobian
        determinant of this map is stored in ``self.log_det_xz``.

        Parameters
        ----------
        x : list
            Tensors ``[x1, x2, *rest]``; ``rest`` is appended to the output
            unchanged.

        Returns
        -------
        list
            ``[z1, z2, *rest]``.
        """
        def _exp(t):
            return keras.backend.exp(t)

        def _sum_both(pair):
            # Reduce each scale tensor on its own, then add the per-sample sums.
            first, second = pair
            return (keras.backend.sum(first, axis=1, keepdims=True)
                    + keras.backend.sum(second, axis=1, keepdims=True))

        x1, x2 = x[0], x[1]
        self.input_x1 = x1
        self.input_x2 = x2

        # Stage 1: scale and shift x2 with networks conditioned on x1.
        self.Sxy_layer = connect(x1, self.S1)
        self.Txy_layer = connect(x1, self.T1)
        mul_1 = keras.layers.Multiply()
        exp_s1 = keras.layers.Lambda(_exp)(self.Sxy_layer)
        scaled_x2 = mul_1([x2, exp_s1])
        z2 = keras.layers.Add()([scaled_x2, self.Txy_layer])
        self.output_z2 = z2

        # Stage 2: scale and shift x1 with networks conditioned on the new z2.
        self.Syz_layer = connect(z2, self.S2)
        self.Tyz_layer = connect(z2, self.T2)
        mul_2 = keras.layers.Multiply()
        exp_s2 = keras.layers.Lambda(_exp)(self.Syz_layer)
        scaled_x1 = mul_2([x1, exp_s2])
        self.output_z1 = keras.layers.Add()([scaled_x1, self.Tyz_layer])

        # log det(dz/dx) = sum S1(x1) + sum S2(z2). The Jacobian is
        # triangular, so the shift terms T do not contribute.
        self.log_det_xz = keras.layers.Lambda(_sum_both)(
            [self.Sxy_layer, self.Syz_layer])

        passthrough = x[2:]  # any extra tensors ride along untouched
        return [self.output_z1, self.output_z2] + passthrough
    def connect_zx(self, z):
        """Wire the z -> x direction of the affine coupling into the Keras graph.

        This algebraically inverts the x -> z affine coupling
        ``z2 = x2 * exp(S1(x1)) + T1(x1)``, ``z1 = x1 * exp(S2(z2)) + T2(z2)``
        by undoing its two stages in reverse order::

            x1 = (z1 - T2(z2)) * exp(-S2(z2))
            x2 = (z2 - T1(x1)) * exp(-S1(x1))

        Intermediate tensors are cached on ``self`` (``input_z1``,
        ``input_z2``, ``Szy_layer``, ``Tzy_layer``, ``output_x1``,
        ``Syx_layer``, ``Tyx_layer``, ``output_x2``). The log-Jacobian
        determinant of this map is stored in ``self.log_det_zx``.

        Parameters
        ----------
        z : list
            Tensors ``[z1, z2, *rest]``; ``rest`` is appended to the output
            unchanged.

        Returns
        -------
        list
            ``[x1, x2, *rest]``.
        """
        def lambda_negexp(x):
            return keras.backend.exp(-x)

        def lambda_negsum(x):
            # Reduce each scale tensor separately before adding.
            # x[0] is multiplied elementwise against the z1 half and x[1]
            # against the z2 half, so their widths differ when the halves
            # are unevenly sized. An elementwise ``-x[0] - x[1]`` would then
            # fail or broadcast to a wrong value.
            return (keras.backend.sum(-x[0], axis=1, keepdims=True)
                    + keras.backend.sum(-x[1], axis=1, keepdims=True))

        z1 = z[0]
        z2 = z[1]
        self.input_z1 = z1
        self.input_z2 = z2

        # Undo the second stage: x1 = (z1 - T2(z2)) * exp(-S2(z2)).
        y2 = z2
        self.Szy_layer = connect(z2, self.S2)
        self.Tzy_layer = connect(z2, self.T2)
        z1_m_Tz2 = keras.layers.Subtract()([z1, self.Tzy_layer])
        y1 = keras.layers.Multiply()(
            [z1_m_Tz2, keras.layers.Lambda(lambda_negexp)(self.Szy_layer)])

        # Undo the first stage: x2 = (z2 - T1(x1)) * exp(-S1(x1)).
        self.output_x1 = y1
        self.Syx_layer = connect(y1, self.S1)
        self.Tyx_layer = connect(y1, self.T1)
        y2_m_Ty1 = keras.layers.Subtract()([y2, self.Tyx_layer])
        self.output_x2 = keras.layers.Multiply()(
            [y2_m_Ty1, keras.layers.Lambda(lambda_negexp)(self.Syx_layer)])

        # log det(dx/dz) = -(sum S2(z2) + sum S1(x1)). This is the negated
        # x -> z log-determinant, because dx/dz is the inverse Jacobian.
        self.log_det_zx = keras.layers.Lambda(lambda_negsum)(
            [self.Szy_layer, self.Syx_layer])

        # Pass any extra tensors through, as the x -> z direction does.
        return [self.output_x1, self.output_x2] + list(z[2:])
    def connect_zx(self, z):
        """Wire the z -> x direction of the additive coupling into the Keras graph.

        The map has two shift-only stages::

            x2 = z2 + M1(z1)
            x1 = z1 + M2(x2)

        Both stages only add a function of the other half, so the Jacobian
        determinant of this map is 1.

        Parameters
        ----------
        z : list
            Tensors ``[z1, z2, *rest]``; ``rest`` is appended to the output
            unchanged. ``z1``/``z2`` are cached as ``self.input_z1`` /
            ``self.input_z2``.

        Returns
        -------
        list
            ``[x1, x2, *rest]``.
        """
        head, tail = z[0], z[1]
        self.input_z1 = head
        self.input_z2 = tail

        # Stage 1: shift the second half by M1 applied to the first half.
        new_tail = keras.layers.Add()([tail, connect(head, self.M1)])
        # Stage 2: shift the first half by M2 applied to the updated second half.
        new_head = keras.layers.Add()([head, connect(new_tail, self.M2)])

        extras = z[2:]  # any extra tensors ride along untouched
        return [new_head, new_tail] + extras
    def connect_xz(self, x):
        """Wire the x -> z direction of the additive coupling into the Keras graph.

        This inverts the z -> x map ``x2 = z2 + M1(z1)``, ``x1 = z1 + M2(x2)``
        by undoing its stages in reverse order::

            z1 = x1 - M2(x2)
            z2 = x2 - M1(z1)

        Parameters
        ----------
        x : list
            Tensors ``[x1, x2, *rest]``; ``rest`` is appended to the output
            unchanged. ``x1``/``x2`` are cached as ``self.input_x1`` /
            ``self.input_x2``.

        Returns
        -------
        list
            ``[z1, z2, *rest]``.
        """
        left, right = x[0], x[1]
        self.input_x1 = left
        self.input_x2 = right

        # Undo the second z -> x stage: remove M2(x2) from the first half.
        z_left = keras.layers.Subtract()([left, connect(right, self.M2)])
        # Undo the first z -> x stage: remove M1(z1) from the second half.
        z_right = keras.layers.Subtract()([right, connect(z_left, self.M1)])

        extras = x[2:]  # any extra tensors ride along untouched
        return [z_left, z_right] + extras