class CollectOutputAndTarget(Callback):
    """Callback that reads back the current batch and optionally plots it.

    Three TensorFlow variables act as holders for the batch input, the
    ground-truth target and the model prediction. They are presumably
    assigned from outside this class (not visible here -- confirm against
    the training setup). At the end of every batch the holders are
    evaluated and, when the module-level ``showImages`` flag is truthy,
    the first sample is forwarded to the training plot.

    NOTE(review): a second class with the same name is defined right after
    this one in the file, so this definition is shadowed once the module
    has finished loading.
    """

    def __init__(self):
        super(CollectOutputAndTarget, self).__init__()

        # Lists intended for per-batch x_input / y_true / y_pred arrays.
        # Nothing in this class appends to them.
        self.inputs = []
        self.targets = []
        self.outputs = []

        # Holder variables start as a scalar 0.0; `validate_shape=False`
        # lets their shape differ from that initial value, so batches of
        # any size (including a smaller last batch) can be stored.
        self.var_x_input = tf.Variable(0.0, validate_shape=False)
        self.var_y_true = tf.Variable(0.0, validate_shape=False)
        self.var_y_pred = tf.Variable(0.0, validate_shape=False)

        self.train_plotObj = Plot('train', title='Train Predictions')

    def on_batch_end(self, batch, logs=None):
        """Evaluate the holder variables and plot the first sample if enabled.

        Neither ``batch`` nor ``logs`` is used by this method.
        """
        # The holders are always evaluated, even when plotting is disabled.
        x_input = K.eval(self.var_x_input)
        y_true = K.eval(self.var_y_true)
        y_pred = K.eval(self.var_y_pred)

        if not showImages:
            return

        # The indexing below requires 4-D arrays; presumably laid out as
        # (batch, height, width, channels) -- TODO confirm.
        first_input = x_input[0, :, :, :]
        first_target = y_true[0, :, :, 0]
        first_prediction = y_pred[0, :, :, 0]
        self.train_plotObj.showTrainResults(first_input, first_target,
                                            first_prediction)
class CollectOutputAndTarget(Callback):
    """Per-batch hook that evaluates three holder variables and may plot them.

    The instance owns one TensorFlow variable each for the batch input, the
    target and the prediction. How they get filled is not shown in this
    class -- presumably the training code assigns to them (verify against
    the caller). After each batch their values are fetched and, if the
    module-level ``showImages`` flag is truthy, sample 0 is passed to
    ``Plot.showTrainResults``.

    NOTE(review): a class with the same name is defined immediately above
    this one in the file; this later definition is the one the name stays
    bound to.
    """

    def __init__(self):
        super(CollectOutputAndTarget, self).__init__()

        # Containers meant for collected batches; this class never adds to
        # them itself.
        self.inputs = []   # x_input batches
        self.targets = []  # y_true batches
        self.outputs = []  # y_pred batches

        # All three holders are created from a scalar 0.0 with shape
        # validation disabled, so values of a different shape (e.g. a
        # smaller last batch) can be stored in them.
        self.var_x_input = tf.Variable(0.0, validate_shape=False)
        self.var_y_true = tf.Variable(0.0, validate_shape=False)
        self.var_y_pred = tf.Variable(0.0, validate_shape=False)

        self.train_plotObj = Plot('train', title='Train Predictions')

    def on_batch_end(self, batch, logs=None):
        """Fetch the holder values; plot the first sample when enabled.

        The ``batch`` and ``logs`` arguments are accepted but not used.
        """
        # Evaluation order is input, target, prediction; it happens
        # regardless of the `showImages` flag.
        holders = (self.var_x_input, self.var_y_true, self.var_y_pred)
        x_input, y_true, y_pred = [K.eval(holder) for holder in holders]

        if showImages:
            # 4-D indexing: full first sample of the input, and index 0 on
            # the last axis of target and prediction (presumably the
            # channel axis -- TODO confirm).
            sample_input = x_input[0, :, :, :]
            sample_target = y_true[0, :, :, 0]
            sample_prediction = y_pred[0, :, :, 0]
            self.train_plotObj.showTrainResults(
                sample_input, sample_target, sample_prediction)