Example #1
def gradient_penalty_loss(y_true, y_pred, averaged_samples,
                          gradient_penalty_weight):
    """Calculates the gradient penalty loss for a batch of "averaged" samples.
  In Improved WGANs, the 1-Lipschitz constraint is enforced by adding a term to the
  loss function that penalizes the network if the gradient norm moves away from 1.
  However, it is impossible to evaluate this function at all points in the input
  space. The compromise used in the paper is to choose random points on the lines
  between real and generated samples, and check the gradients at these points. Note
  that it is the gradient w.r.t. the input averaged samples, not the weights of the
  discriminator, that we're penalizing!
  In order to evaluate the gradients, we must first run samples through the generator
  and evaluate the loss. Then we get the gradients of the discriminator w.r.t. the
  input averaged samples. The l2 norm and penalty can then be calculated for this
  gradient.
  Note that this loss function requires the original averaged samples as input, but
  Keras only supports passing y_true and y_pred to loss functions. To get around this,
  we make a partial() of the function with the averaged_samples argument, and use that
  for model training."""
    # first get the gradients:
    #   assuming: - that y_pred has dimensions (batch_size, 1)
    #             - averaged_samples has dimensions (batch_size, nbr_features)
    # gradients afterwards has dimension (batch_size, nbr_features), basically
    # a list of nbr_features-dimensional gradient vectors
    gradients = K.gradients(y_pred, averaged_samples)[0]
    # compute the euclidean norm by squaring ...
    gradients_sqr = K.square(gradients)
    #   ... summing over the rows ...
    gradients_sqr_sum = K.sum(gradients_sqr,
                              axis=np.arange(1, len(gradients_sqr.shape)))
    #   ... and sqrt
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    # compute lambda * (1 - ||grad||)^2 still for each single sample
    gradient_penalty = gradient_penalty_weight * K.square(1 - gradient_l2_norm)
    # return the mean as loss over all the batch samples
    return K.mean(gradient_penalty)
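Because Keras only passes y_true and y_pred to a loss function, the partial() trick described in the docstring is usually wired up as in the sketch below; discriminator_model, averaged_samples and GRADIENT_PENALTY_WEIGHT are placeholder names, and wasserstein_loss is assumed to be defined as in Example #13 further down.

from functools import partial

# Hypothetical wiring: `averaged_samples` is the interpolated-samples tensor and
# `discriminator_model` a Keras model whose third output is scored on those samples.
partial_gp_loss = partial(gradient_penalty_loss,
                          averaged_samples=averaged_samples,
                          gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
partial_gp_loss.__name__ = 'gradient_penalty'  # Keras losses need a __name__

discriminator_model.compile(optimizer='adam',
                            loss=[wasserstein_loss,
                                  wasserstein_loss,
                                  partial_gp_loss])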
Example #2
 def focal_loss_fixed(y_true, y_pred):
     # `gamma` is expected to come from an enclosing scope (see the wrapper sketch below)
     eps = 1e-6
     alpha = 0.5
     # clip predictions to improve the numerical stability of the focal loss; see issue #1 for details
     y_pred = K.clip(y_pred, eps, 1. - eps)
     pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
     pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
     return (-K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1))
             - K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0), axis=-1))
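Snippets like this one are typically returned from a wrapper that binds gamma (and often alpha) via a closure; a minimal sketch of that pattern, where the imports and default values are assumptions rather than part of the snippet above:

import tensorflow as tf
from keras import backend as K

def focal_loss(gamma=2., alpha=0.5):
    """Return a Keras-compatible loss with `gamma` and `alpha` bound via closure."""
    def focal_loss_fixed(y_true, y_pred):
        eps = 1e-6
        y_pred = K.clip(y_pred, eps, 1. - eps)  # clip for numerical stability
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        return (-K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1))
                - K.sum((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0), axis=-1))
    return focal_loss_fixed

# hypothetical usage: model.compile(optimizer='adam', loss=focal_loss(gamma=2., alpha=0.5))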
Example #3
def custom_mean_squared_loss(y_true, y_pred):
    print(y_true)
    print(y_pred)
    diff = K.abs(y_true - y_pred)
    angle_diff = K.minimum(diff[:, :, 6:], 360 - diff[:, :, 6:])
    error = tf.concat([diff[:, :, :6], angle_diff], axis=-1)

    return K.mean(K.square(error), axis=-1)
Example #4
def denseloss(y_true, y_pred, e=1000):
    Le = KTF.mean(KTF.square(y_pred-y_true), axis=-1)
    Lc = get_avgpoolLoss(y_true, y_pred, 1)
    Lc += get_avgpoolLoss(y_true, y_pred, 2)
    Lc += get_avgpoolLoss(y_true, y_pred, 4)
    shp = KTF.get_variable_shape(y_pred)
    Lc = Lc / (shp[1] * shp[2])
    return Le + e * Lc
Example #5
def dice_coef(y_true, y_pred, smooth=1e-3):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return K.mean(
        (2.0 * intersection + smooth)
        / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    )
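For reference, dice_coef implements the soft Dice coefficient over the flattened masks, with a small smoothing term s keeping the ratio defined on empty masks:

\mathrm{Dice}(y, \hat{y}) = \frac{2 \sum_i y_i \hat{y}_i + s}{\sum_i y_i + \sum_i \hat{y}_i + s}, \qquad s = 10^{-3}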
Example #6
def confidence_reconstruction_loss(y_true, y_pred, mask, num_steps,
                                   gaussian_kernel_size, gaussian_kernel_std):
    mask_blurred = gaussian_utils.blur_mask(mask, num_steps,
                                            gaussian_kernel_size,
                                            gaussian_kernel_std)
    valid_mask = 1 - mask
    diff = K.abs(y_true - y_pred)
    l1 = K.mean(diff * valid_mask + diff * mask_blurred, axis=[1, 2, 3])
    return l1
Example #7
 def _compute_cost_huber(self, q, a, r, t, q2):
     preds = slice_tensor_tensor(q, a)
     bootstrap = K.max if not self.use_mean else K.mean
     targets = r + (1 - t) * self.gamma * bootstrap(q2, axis=1)
     err = targets - preds
     cond = K.abs(err) > 1.0
     L2 = 0.5 * K.square(err)         # quadratic branch (small errors)
     L1 = K.abs(err) - 0.5            # linear branch (large errors)
     cost = tf.where(cond, L1, L2)    # use the linear branch when |err| > 1.0
     return K.mean(cost)
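For reference, this is the Huber cost with delta = 1, quadratic for small TD errors and linear for large ones:

L_{\delta}(e) =
\begin{cases}
\tfrac{1}{2} e^{2} & \text{if } |e| \le \delta \\
\delta \left( |e| - \tfrac{1}{2}\delta \right) & \text{otherwise}
\end{cases}
\qquad \delta = 1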
Example #8
        def loss_func(y_true, y_pred_mll):
            y_true = y_true[:, 0]
            y_pred = y_pred_mll[:, 0]
            mll_pred = y_pred_mll[:, 1]

            mll_loss = K.mean(K.abs(mll_pred - 91.2))
            mll_sigma_loss = K.abs(K.std(mll_pred) - 7.67)

            return binary_crossentropy(
                y_true, y_pred) + c * mll_loss + c * mll_sigma_loss
Example #9
def margin_loss(y_true, y_pred):
	"""
	Margin loss for Eq.(4). When y_true[i, :] contains not just one `1`, this loss should work too. Not test it.
	:param y_true: [None, n_classes]
	:param y_pred: [None, num_capsule]
	:return: a scalar loss value.
	"""
	L = y_true * KTF.square(KTF.maximum(0., 0.9 - y_pred)) + \
			0.5 * (1 - y_true) * KTF.square(KTF.maximum(0., y_pred - 0.1))

	return KTF.mean(KTF.sum(L, 1))
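Written out, this is the CapsNet margin loss of Eq. (4), with m+ = 0.9, m- = 0.1 and a down-weighting factor lambda = 0.5; y_pred holds the capsule lengths ||v_k||, y_true the one-hot targets T_k, and the returned value is the batch mean of the per-sample sum over k:

L_k = T_k \, \max(0,\, m^{+} - \lVert \mathbf{v}_k \rVert)^{2}
      + \lambda \, (1 - T_k) \, \max(0,\, \lVert \mathbf{v}_k \rVert - m^{-})^{2}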
Example #10
        def loss_func(y_true, y_pred_mll):
            y_true = y_true[:, 0]
            y_pred = y_pred_mll[:, 0]
            mll_pred = y_pred_mll[:, 1]

            mll_loss = K.mean(K.abs(mll_pred - 91.2))

            #         pseudomll = K.random_normal_variable(shape=(1,1), mean=91.2, scale=2)
            #         mll_loss = K.mean((mll_pred - pseudomll)**2)

            return binary_crossentropy(y_true, y_pred) + c * mll_loss
Example #11
def dice_axon(y_true, y_pred, smooth=1e-3):
    """
    Computes the pixel-wise dice myelin coefficient from the prediction tensor outputted by the network.
    :param y_pred: Tensor, the prediction outputed by the network. Shape (N,H,W,C).
    :param y_true: Tensor, the gold standard we work with. Shape (N,H,W,C).
    :return: dice axon coefficient for the current batch.
    """

    y_true_f = K.flatten(y_true[..., 2])
    y_pred_f = K.flatten(y_pred[..., 2])
    intersection = K.sum(y_true_f * y_pred_f)
    return K.mean((2. * intersection + smooth) /
                  (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))
Example #12
def f1_loss(y_true, y_pred):
    
    tp = K.sum(K.cast(y_true*y_pred, 'float'), axis=0)
    tn = K.sum(K.cast((1-y_true)*(1-y_pred), 'float'), axis=0)
    fp = K.sum(K.cast((1-y_true)*y_pred, 'float'), axis=0)
    fn = K.sum(K.cast(y_true*(1-y_pred), 'float'), axis=0)

    p = tp / (tp + fp + K.epsilon())
    r = tp / (tp + fn + K.epsilon())

    f1 = 2*p*r / (p+r+K.epsilon())
    f1 = tf.where(tf.is_nan(f1), tf.zeros_like(f1), f1)
    return 1 - K.mean(f1)
Example #13
def wasserstein_loss(y_true, y_pred, wgan_loss_weight=1.0):
    """Calculates the Wasserstein loss for a sample batch.
  The Wasserstein loss function is very simple to calculate. In a standard GAN, the
  discriminator has a sigmoid output, representing the probability that samples are
  real or generated. In Wasserstein GANs, however, the output is linear with no
  activation function! Instead of being constrained to [0, 1], the discriminator wants
  to make the distance between its output for real and generated samples as
  large as possible.
  The most natural way to achieve this is to label generated samples -1 and real
  samples 1, instead of the 0 and 1 used in normal GANs, so that multiplying the
  outputs by the labels will give you the loss immediately.
  Note that the nature of this loss means that it can be (and frequently will be)
  less than 0."""
    return wgan_loss_weight * K.mean(y_true * y_pred)
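Since the loss is just the mean of label times critic output, training amounts to feeding +1 labels for real samples and -1 for generated ones; a minimal sketch, where discriminator, real_images, generated_images and batch_size are placeholder names:

import numpy as np

# Hypothetical usage: `discriminator` is a Keras model compiled with
# loss=wasserstein_loss; labels are +1 for real and -1 for generated samples.
positive_y = np.ones((batch_size, 1), dtype=np.float32)
negative_y = -positive_y

d_loss_real = discriminator.train_on_batch(real_images, positive_y)
d_loss_fake = discriminator.train_on_batch(generated_images, negative_y)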
Example #14
    def loss(self,y_true,y_pred):

        """ executes the categorical cross-entropy

            # Arguments
                y_true : true class values
                y_pred : predicted class values from the model
            # Returns
                ce : mean cross-entropy for the given batch
        """
        y_pred = super().clipping(y_pred)
        ce = -(K.sum((super().c_weights(self.class_weights) * (y_true * K.log(y_pred))),axis=-1))
        ce = K.sum((super().p_weights(self.pixel_weights) * ce),axis=(1,2))
        ce = K.mean(ce,axis=0)
        return ce/1000                                ## scaling down the loss to prevent gradient explosion
Example #15
    def loss(self,y_true,y_pred):

        """ executes the focal loss

            # Arguments
                y_true : true class values
                y_pred : predicted class values from the model
            # Returns
                fl : mean focal loss for the given batch
         """
        y_pred = self.clipping(y_pred)
        fl = -(K.sum((self.c_weights(self.class_weights) * K.pow(1.-y_pred,self.gamma) * (y_true * K.log(y_pred))),axis=-1))
        fl = K.sum((self.p_weights(self.pixel_weights) * fl),axis=(1,2))
        fl = K.mean(fl, axis=0)
        return fl/1000                                   ## scaling down the loss to prevent gradient explosion
Example #16
def get_target_embeddings(inputs, trg_seq, keyword_op='mean'):
    '''
    Returns the mean of the target embeddings, to be concatenated with the input
    for TC-LSTM.
    trg_seq is a vector of zeros and ones, with ones at the target positions.
    Embeddings are selected at the target positions and then averaged.
    '''

    units = int(inputs.shape[2])
    trg_seq = RepeatVector(units)(trg_seq)
    trg_seq = Permute([2, 1])(trg_seq)
    trg_embeddings = Multiply()([inputs, trg_seq])
    trg_embeddings = Lambda(lambda x: K.mean(x, axis=1),
                            output_shape=lambda s:
                            (s[0], s[2]))(trg_embeddings)
    return trg_embeddings
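A sketch of how the averaged target embedding is typically used in TC-LSTM, repeated across time and concatenated with every word embedding; the layer wiring and the dimensions max_len and emb_dim are assumptions for illustration, not part of the snippet above:

from keras.layers import Input, LSTM, RepeatVector, Concatenate

max_len, emb_dim = 40, 100                          # placeholder dimensions
word_embeddings = Input(shape=(max_len, emb_dim))   # embedded sentence
target_mask = Input(shape=(max_len,))               # zeros/ones marking target tokens

target_vec = get_target_embeddings(word_embeddings, target_mask)  # (batch, emb_dim)
target_seq = RepeatVector(max_len)(target_vec)                    # repeat over time
tc_input = Concatenate(axis=-1)([word_embeddings, target_seq])    # (batch, max_len, 2*emb_dim)
hidden = LSTM(128)(tc_input)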
Example #17
    def loss(self,y_true,y_pred):

        """ executes the dice loss

            # Arguments
                y_true : true class values
                y_pred : predicted class values from the model
            # Returns
                dl : dice loss for the given batch
        """
        y_pred = super().clipping(y_pred)
        intersection = K.sum((super().c_weights(self.class_weights) * y_true * y_pred),axis=-1)
        intersection = K.sum((super().p_weights(self.pixel_weights) * intersection),axis=(1,2))
        union = K.sum( (super().c_weights(self.class_weights)*((y_true*y_true) + (y_pred*y_pred)) ),axis=-1)
        union = K.sum((super().p_weights(self.pixel_weights) * union),axis=(1,2))
        dl = 1. - ((2*intersection)/union)
        return K.mean(dl)
Example #18
def grad_cam(model,inputdata):
    inputdata = np.reshape(inputdata, [1, 300, 1])

    #try using book
    model_output = model.output[:,0]
    last_conv_layer= model.get_layer('conv1d_1')
    grads = K.gradients(model_output,last_conv_layer.output)[0]
    pooled_grads = K.mean(grads,axis=(0,1))

    iterate = K.function([model.input],
                         [pooled_grads,last_conv_layer.output])

    pooled_grads_value,conv_layer_output_value = iterate([inputdata])



    for i in range(len(pooled_grads_value)):
        conv_layer_output_value[:,:,i] *= pooled_grads_value[i]

    grad_cam = np.average(conv_layer_output_value, 0)


    cam_data = []
    for i in range(len(grad_cam)):
        cam_data.append(np.average(grad_cam[i,:]))
    cam_data = np.reshape(cam_data,[1,len(cam_data)])

    from scipy.signal import savgol_filter
    #test_cam2 = savgol_filter(cam_data, 3,0)

    test_cam2 = np.resize(cam_data,[1,300])

    """
    fig = plt.figure(figsize=(20, 10))
    ax0 = plt.subplot2grid((1, 1), (0, 0), colspan=1)
    plt.yticks(fontsize=15)
    ax0.plot(inputdata.flatten(),c='blue')
    ax0_2 = ax0.twinx()
    ax0_2.imshow(test_cam2,cmap='gist_heat',aspect='auto',alpha=0.4)

    #plt.show()
    """

    return test_cam2
Example #19
File: train.py  Project: aasensio/DNHazel
    def mean_log_Gaussian_like(self, y_true, parameters):
        """Mean Log Gaussian Likelihood distribution
        Note: The 'c' variable is obtained as global variable
        """
        components = ktf.reshape(parameters,[-1, 2*9 + 1, self.n_classes])
        
        mu = components[:, 0:9, :]
        sigma = components[:, 9:18, :]
        alpha = components[:, 18, :]

        alpha = ktf.softmax(ktf.clip(alpha,1e-8,1.))
        
        exponent = ktf.log(alpha) - .5 * float(self.c) * ktf.log(2 * np.pi) \
            - ktf.sum(ktf.log(sigma), axis=1) \
            - ktf.sum((ktf.expand_dims(y_true,2) - mu)**2 / (2*(sigma)**2), axis=1)
        
        log_gauss = log_sum_exp(exponent, axis=1)
        res = - ktf.mean(log_gauss)
        return res        
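For reference, the exponent assembled here is the per-component log density of a mixture of diagonal Gaussians (assuming self.c equals the number of output dimensions, D = 9, with K = n_classes mixture components), and the loss is the negative batch mean of its log-sum-exp:

\log p(\mathbf{y}) = \log \sum_{k=1}^{K} \alpha_k
    \prod_{d=1}^{D} \frac{1}{\sqrt{2\pi}\,\sigma_{k,d}}
    \exp\!\left( -\frac{(y_d - \mu_{k,d})^{2}}{2\,\sigma_{k,d}^{2}} \right)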
Example #20
def norm_layer(x, axis=1):
    return (x - K.mean(x, axis=axis, keepdims=True)) / K.std(
        x, axis=axis, keepdims=True)
Example #21
        def c_loss(ytrue, ypred):
            loss = mae(ytrue, ypred)
            print(type(loss))
            loss += self.loss_weight * K.square((K.mean(ypred) - self.center))

            return loss
Example #22
def wgan_d_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)
Example #23
def r2_metrics(y_true, y_pred):
    ssr = K.mean((y_pred - y_true)**2, axis=0)                   # residual sum of squares, per output
    sst = K.mean((y_true - K.mean(y_true, axis=0))**2, axis=0)   # total sum of squares, per output
    return 1 - K.mean(ssr / sst, axis=-1)
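For reference, this computes the usual coefficient of determination per output and then averages over outputs:

R^{2} = 1 - \frac{\sum_i (y_i - \hat{y}_i)^{2}}{\sum_i (y_i - \bar{y})^{2}}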
Example #24
 def loss_func(y_true, y_pred):
     return ktf.sqrt(ktf.mean(ktf.square(y_pred - y_true)))
Example #25
class Metrics:

    __weighted_method = lambda self, x, y, string, w: (K.sum(x) / K.sum(
        y)) if string == 'inter' else (K.sum(w * x) / K.sum(w * y))
    __avg_method = lambda self, x, y, string, w: K.mean(
        x / y) if string == 'intra' else self.__weighted_method(
            x, y, string, w)

    def __metrics_base(self, y_true, y_pred):
        """ Base for all the metrics defined below """
        y_true, y_pred = K.flatten(tf.math.argmax(y_true, axis=-1)), K.flatten(
            tf.math.argmax(y_pred, axis=-1))
        con_mat = K.cast(tf.math.confusion_matrix(y_true, y_pred), K.floatx())
        correct = tf.linalg.diag_part(con_mat)
        total = K.sum(con_mat, axis=-1)
        return correct, total, con_mat

    def accuracy(self, y_true, y_pred):
        """ computes the accuracy

            # Arguments
                y_true : target value
                y_pred : predicted class value
            # Returns
                acc : overall accuracy
        """
        correct, total, _ = self.__metrics_base(y_true, y_pred)
        return (K.sum(correct) / K.sum(total))

    def IoU(self, y_true, y_pred, average='inter', weights=None):
        """ Intersection over Union , IoU = A^B/(A U B - A^B)
           Computes the percentage overlap with the target image.

            # Arguments
                y_true : target value
                y_pred : predicted class value
                average : 'inter' --> computes the IoU score over all classes at once; 'intra' --> computes the score for each class and averages them;
                        'weighted' --> computes a weighted average, useful for imbalanced classes.
                weights : weights for the respective classes; used only when average='weighted'.
            # Returns
                IoU score
        """
        _, _, con_mat = self.__metrics_base(y_true, y_pred)
        intersection = tf.linalg.diag_part(con_mat)
        ground_truth_set = K.sum(con_mat, axis=1)
        predicted_set = K.sum(con_mat, axis=0)
        union = ground_truth_set + predicted_set - intersection
        return self.__avg_method(intersection, union, average, weights)

    def recall(self, y_true, y_pred, average='inter', weights=None):
        """ Computes the recall score over each given class and gives the overall score.  recall = TP/TP+FN

            # Arguments
                y_true : target value
                y_pred : predicted class value
                average : 'inter' --> computes the recall score over all classes at once; 'intra' --> computes the score for each class and averages them;
                        'weighted' --> computes a weighted average, useful for imbalanced classes.
                weights : weights for the respective classes; used only when average='weighted'.
            # Returns
                recall score
        """
        correct, total, _ = self.__metrics_base(y_true, y_pred)
        return self.__avg_method(correct, total, average, weights)

    def precision(self, y_true, y_pred, average='inter', weights=None):
        """ Computes the precision over each given class and returns the overall score.  precision = TP/TP+FP

            # Arguments
                y_true : target value
                y_pred : predicted class value
                average : 'inter' --> computes the precision score over all classes at once; 'intra' --> computes the score for each class and averages them;
                        'weighted' --> computes a weighted average, useful for imbalanced classes.
                weights : weights for the respective classes; used only when average='weighted'.
            # Returns
                precision score
        """
        correct, _, con_mat = self.__metrics_base(y_true, y_pred)
        total = K.sum(con_mat, axis=0)
        return self.__avg_method(correct, total, average, weights)

    def f1score(self, y_true, y_pred, average='inter', weights=None):
        """ Computes the f1 score over each given class and returns the overall score.  f1 = (2*precision*recall)/(precision+recall)

            # Arguments
                y_true : target value
                y_pred : predicted class value
                average : 'inter' --> computes the f1 score over all classes at once; 'intra' --> computes the score for each class and averages them;
                            'weighted' --> computes a weighted average, useful for imbalanced classes.
                weights : weights for the respective classes; used only when average='weighted'.
            # Returns
                 f1 score
        """
        precision = self.precision(y_true, y_pred, average, weights)
        recall = self.recall(y_true, y_pred, average, weights)
        return ((2 * precision * recall) / (precision + recall))

    def dice_coeffiecient(self, y_true, y_pred, average='inter', weights=None):
        """ Computes the dice score over each given class and returns the overall score.

                # Arguments
                    y_true : target value
                    y_pred : predicted class value
                    average : 'inter' --> computes the dice score over all classes at once; 'intra' --> computes the score for each class and averages them;
                                    'weighted' --> computes a weighted average, useful for imbalanced classes.
                    weights : weights for the respective classes; used only when average='weighted'.
                # Returns
                    dice score
                """

        y_pred = focal_loss.clipping(y_pred)
        intersection = 2 * K.sum((y_true * y_pred), axis=(0, 1, 2))
        union = K.sum((y_true * y_true) + (y_pred * y_pred), axis=(0, 1, 2))
        return self.__avg_method(intersection, union, average, weights)
Example #26
File: seq2seq.py  Project: yynjupt/Yuan
def root_mean_squared_error(y_true, y_pred):
    return KTF.sqrt(KTF.mean(KTF.square(y_pred - y_true)))
Example #27
File: began.py  Project: aasensio/DNHazel
    def gan(self):
        # initialize a GAN trainer

        # this is the fastest way to train a GAN in Keras:
        # two models are updated simultaneously in one pass

        noise = Input(shape=self.generator.input_shape[1:])
        real_data = Input(shape=self.discriminator.input_shape[1:])

        generated = self.generator(noise)
        gscore = self.discriminator(generated)
        rscore = self.discriminator(real_data)

        def log_eps(i):
            return K.log(i+1e-11)

        # single side label smoothing: replace 1.0 with 0.9
        dloss = - K.mean(log_eps(1-gscore) + .1 * log_eps(1-rscore) + .9 * log_eps(rscore))
        gloss = - K.mean(log_eps(gscore))

        Adam = tf.train.AdamOptimizer

        lr,b1 = 1e-4,.2 # otherwise won't converge.
        optimizer = Adam(lr)

        grad_loss_wd = optimizer.compute_gradients(dloss, self.discriminator.trainable_weights)
        update_wd = optimizer.apply_gradients(grad_loss_wd)

        grad_loss_wg = optimizer.compute_gradients(gloss, self.generator.trainable_weights)
        update_wg = optimizer.apply_gradients(grad_loss_wg)

        def get_internal_updates(model):
            # get all internal update ops (like moving averages) of a model
            inbound_nodes = model.inbound_nodes
            input_tensors = []
            for ibn in inbound_nodes:
                input_tensors+= ibn.input_tensors
            updates = [model.get_updates_for(i) for i in input_tensors]
            return updates

        other_parameter_updates = [get_internal_updates(m) for m in [self.discriminator,self.generator]]
        # those updates includes batch norm.

        print('other_parameter_updates for the models(mainly for batch norm):')
        print(other_parameter_updates)

        train_step = [update_wd, update_wg, other_parameter_updates]
        losses = [dloss,gloss]

        learning_phase = K.learning_phase()

        def gan_feed(sess,batch_image,z_input):
            # actual GAN trainer
            nonlocal train_step,losses,noise,real_data,learning_phase

            res = sess.run([train_step,losses],feed_dict={
            noise:z_input,
            real_data:batch_image,
            learning_phase:True,
            # Keras layers need to know whether
            # this run is training or testing (e.g. for batch norm and dropout)
            })

            loss_values = res[1]
            return loss_values #[dloss,gloss]

        return gan_feed
Example #28
def get_avgpoolLoss(y_true, y_pred, k):
    loss = KTF.mean((abs(AveragePooling2D(pool_size=(k, k), strides=(1, 1))(y_true) -
                AveragePooling2D(pool_size=(k, k), strides=(1, 1))(y_pred)))) / k
    return loss
Example #29
def reconstruction_loss(y_true, y_pred):
    diff = K.abs(y_pred - y_true)
    l1 = K.mean(diff, axis=[1, 2, 3])
    return l1
Example #30
File: DCGAN.py  Project: USTB-ML/code_hub
def wasserstein(y_true, y_pred):
    return K.mean(y_true * y_pred)
Example #31
 def normalize(self, x):
     # utility function to normalize a tensor by its L2 norm
     return x / (KTF.sqrt(KTF.mean(KTF.square(x)) + 1e-5))
Example #32
    def step(self, a, states):
        r_tm1 = states[:self.nb_layers]
        c_tm1 = states[self.nb_layers:2 * self.nb_layers]
        e_tm1 = states[2 * self.nb_layers:3 * self.nb_layers]

        if self.extrap_start_time is not None:
            t = states[-1]
            a = K.switch(
                t >= self.t_extrap, states[-2], a
            )  # if past self.extrap_start_time, the previous prediction will be treated as the actual

        c = []
        r = []
        e = []

        # Update R units starting from the top
        for l in reversed(range(self.nb_layers)):
            inputs = [r_tm1[l], e_tm1[l]]
            if l < self.nb_layers - 1:
                inputs.append(r_up)

            inputs = K.concatenate(inputs, axis=self.channel_axis)
            i = self.conv_layers['i'][l].call(inputs)
            f = self.conv_layers['f'][l].call(inputs)
            o = self.conv_layers['o'][l].call(inputs)
            _c = f * c_tm1[l] + i * self.conv_layers['c'][l].call(inputs)
            _r = o * self.LSTM_activation(_c)
            c.insert(0, _c)
            r.insert(0, _r)

            if l > 0:
                r_up = self.upsample.call(_r)

        # Update feedforward path starting from the bottom
        for l in range(self.nb_layers):
            ahat = self.conv_layers['ahat'][l].call(r[l])
            if l == 0:
                ahat = K.minimum(ahat, self.pixel_max)
                frame_prediction = ahat

            # compute errors
            e_up = self.error_activation(ahat - a)
            e_down = self.error_activation(a - ahat)

            e.append(K.concatenate((e_up, e_down), axis=self.channel_axis))

            if self.output_layer_num == l:
                if self.output_layer_type == 'A':
                    output = a
                elif self.output_layer_type == 'Ahat':
                    output = ahat
                elif self.output_layer_type == 'R':
                    output = r[l]
                elif self.output_layer_type == 'E':
                    output = e[l]

            if l < self.nb_layers - 1:
                a = self.conv_layers['a'][l].call(e[l])
                a = self.pool.call(a)  # target for next layer

        if self.output_layer_type is None:
            if self.output_mode == 'prediction':
                output = frame_prediction
            else:
                for l in range(self.nb_layers):
                    layer_error = K.mean(K.batch_flatten(e[l]),
                                         axis=-1,
                                         keepdims=True)
                    all_error = layer_error if l == 0 else K.concatenate(
                        (all_error, layer_error), axis=-1)
                if self.output_mode == 'error':
                    output = all_error
                else:
                    output = K.concatenate(
                        (K.batch_flatten(frame_prediction), all_error),
                        axis=-1)

        states = r + c + e
        if self.extrap_start_time is not None:
            states += [frame_prediction, t + 1]
        return output, states