def gram_matrix(input_tensor):
    """Return the Gram matrix of a rank-3 tensor.

    The last axis of ``input_tensor`` is moved to the front and each slice
    along it is flattened into a row vector. The result is the matrix of
    pairwise dot products between those rows, i.e. ``F @ F.T``.
    NOTE(review): presumably the input is a single (height, width, channels)
    feature map, giving a (channels, channels) result — confirm at call sites.

    Raises:
        AssertionError: if ``input_tensor`` is not rank 3.
    """
    assert K.ndim(input_tensor) == 3

    # (d0, d1, d2) -> (d2, d0, d1) -> (d2, d0 * d1)
    moved = K.permute_dimensions(input_tensor, (2, 0, 1))
    rows = K.batch_flatten(moved)
    return K.dot(rows, K.transpose(rows))
Esempio n. 2
0
    def call(self, x, mask=None):
        """Combine ``[L_flat, mu, a]`` into -0.5 * (a - mu)^T P (a - mu).

        Args:
            x: Sequence of exactly three tensors ``[L_flat, mu, a]`` (enforced
                by the assert below). Shapes are not validated here; presumably
                ``mu`` and ``a`` are (batch, nb_actions), and ``L_flat`` holds
                nb_actions * (nb_actions + 1) / 2 values per row in 'full' mode
                or nb_actions values per row in 'diag' mode — TODO confirm.
            mask: Unused.

        Depending on ``self.mode``:
            'full': each row of ``L_flat`` is scattered into a lower-triangular
                matrix L (row-major order of ``np.tril_indices``). The diagonal
                is replaced by exp(diagonal) + K.epsilon(). Then P = L @ L^T.
            'diag': each row of ``L_flat`` is placed, unchanged, on the diagonal
                of P. NOTE(review): no exp/epsilon is applied in this mode, so
                positivity of P depends on the values passed in.

        Returns:
            Tensor A = -0.5 * (a - mu)^T P (a - mu) with shape (batch, 1).

        Raises:
            RuntimeError: if the Keras backend is neither Theano nor TensorFlow.

        NOTE(review): if ``self.mode`` is neither 'full' nor 'diag', ``P`` is
        never assigned, and ``assert P is not None`` raises NameError.
        """
        # TODO: validate input shape

        assert (len(x) == 3)
        L_flat = x[0]
        mu = x[1]
        a = x[2]

        if self.mode == 'full':
            # Create L and L^T matrix, which we use to construct the positive-definite matrix P.
            L = None
            LT = None
            if K.backend() == 'theano':
                import theano.tensor as T
                import theano

                # Scan step: `x` is one row of L_flat. The two accumulators are
                # part of scan's step signature (outputs_info) but are unused.
                def fn(x, L_acc, LT_acc):
                    x_ = K.zeros((self.nb_actions, self.nb_actions))
                    # Fill the lower triangle (row-major) with the flat values.
                    x_ = T.set_subtensor(x_[np.tril_indices(self.nb_actions)], x)
                    # Replace the diagonal with exp(diagonal) + epsilon.
                    diag = K.exp(T.diag(x_)) + K.epsilon()
                    x_ = T.set_subtensor(x_[np.diag_indices(self.nb_actions)], diag)
                    return x_, x_.T

                outputs_info = [
                    K.zeros((self.nb_actions, self.nb_actions)),
                    K.zeros((self.nb_actions, self.nb_actions)),
                ]
                results, _ = theano.scan(fn=fn, sequences=L_flat, outputs_info=outputs_info)
                L, LT = results
            elif K.backend() == 'tensorflow':
                import tensorflow as tf

                # Number of elements in a triangular matrix.
                nb_elems = (self.nb_actions * self.nb_actions + self.nb_actions) // 2

                # Create mask for the diagonal elements in L_flat. This is used to exponentiate
                # only the diagonal elements, which is done before gathering.
                # In row-major lower-triangular packing, row r's diagonal entry
                # sits r + 1 positions after row r - 1's (0, 2, 5, 9, ...).
                diag_indeces = [0]
                for row in range(1, self.nb_actions):
                    diag_indeces.append(diag_indeces[-1] + (row + 1))
                diag_mask = np.zeros(1 + nb_elems)  # +1 for the leading zero
                diag_mask[np.array(diag_indeces) + 1] = 1
                diag_mask = K.variable(diag_mask)

                # Add leading zero element to each element in the L_flat. We use this zero
                # element when gathering L_flat into a lower triangular matrix L.
                nb_rows = tf.shape(L_flat)[0]
                zeros = tf.expand_dims(tf.tile(K.zeros((1,)), [nb_rows]), 1)
                try:
                    # Old TF behavior.
                    L_flat = tf.concat(1, [zeros, L_flat])
                except (TypeError, ValueError):
                    # New TF behavior
                    L_flat = tf.concat([zeros, L_flat], 1)

                # Create mask that can be used to gather elements from L_flat and put them
                # into a lower triangular matrix.
                # Entries left at 0 (the upper triangle) gather the leading zero.
                tril_mask = np.zeros((self.nb_actions, self.nb_actions), dtype='int32')
                tril_mask[np.tril_indices(self.nb_actions)] = range(1, nb_elems + 1)

                # Finally, process each element of the batch.
                init = [
                    K.zeros((self.nb_actions, self.nb_actions)),
                    K.zeros((self.nb_actions, self.nb_actions)),
                ]

                # tf.scan step: `a` is the (unused) accumulator and only shadows
                # the outer action tensor inside this function. `x` is one
                # zero-padded row of L_flat.
                def fn(a, x):
                    # Exponentiate everything. This is much easier than only exponentiating
                    # the diagonal elements, and, usually, the action space is relatively low.
                    x_ = K.exp(x) + K.epsilon()
                    # Only keep the diagonal elements.
                    x_ *= diag_mask
                    # Add the original, non-diagonal elements.
                    x_ += x * (1. - diag_mask)
                    # Finally, gather everything into a lower triangular matrix.
                    L_ = tf.gather(x_, tril_mask)
                    return [L_, tf.transpose(L_)]

                tmp = tf.scan(fn, L_flat, initializer=init)
                if isinstance(tmp, (list, tuple)):
                    # TensorFlow 0.10 now returns a tuple of tensors.
                    L, LT = tmp
                else:
                    # Old TensorFlow < 0.10 returns a shared tensor.
                    L = tmp[:, 0, :, :]
                    LT = tmp[:, 1, :, :]
            else:
                raise RuntimeError('Unknown Keras backend "{}".'.format(K.backend()))
            assert L is not None
            assert LT is not None
            # Batched P = L @ L^T.
            P = K.batch_dot(L, LT)
        elif self.mode == 'diag':
            if K.backend() == 'theano':
                import theano.tensor as T
                import theano

                # Scan step: place row `x` on the diagonal of a zero matrix.
                # `P_acc` is required by scan's signature but unused.
                def fn(x, P_acc):
                    x_ = K.zeros((self.nb_actions, self.nb_actions))
                    x_ = T.set_subtensor(x_[np.diag_indices(self.nb_actions)], x)
                    return x_

                outputs_info = [
                    K.zeros((self.nb_actions, self.nb_actions)),
                ]
                P, _ = theano.scan(fn=fn, sequences=L_flat, outputs_info=outputs_info)
            elif K.backend() == 'tensorflow':
                import tensorflow as tf

                # Create mask that can be used to gather elements from L_flat and put them
                # into a diagonal matrix.
                diag_mask = np.zeros((self.nb_actions, self.nb_actions), dtype='int32')
                diag_mask[np.diag_indices(self.nb_actions)] = range(1, self.nb_actions + 1)

                # Add leading zero element to each element in the L_flat. We use this zero
                # element when gathering L_flat into the diagonal matrix P (all
                # off-diagonal entries gather index 0).
                nb_rows = tf.shape(L_flat)[0]
                zeros = tf.expand_dims(tf.tile(K.zeros((1,)), [nb_rows]), 1)
                try:
                    # Old TF behavior.
                    L_flat = tf.concat(1, [zeros, L_flat])
                except (TypeError, ValueError):
                    # New TF behavior
                    L_flat = tf.concat([zeros, L_flat], 1)

                # Finally, process each element of the batch.
                # `a` is the unused scan accumulator (shadows the outer `a` only here).
                def fn(a, x):
                    x_ = tf.gather(x, diag_mask)
                    return x_

                P = tf.scan(fn, L_flat, initializer=K.zeros((self.nb_actions, self.nb_actions)))
            else:
                raise RuntimeError('Unknown Keras backend "{}".'.format(K.backend()))
        assert P is not None
        assert K.ndim(P) == 3

        # Combine a, mu and P into a scalar (over the batches). What we compute here is
        # -.5 * (a - mu)^T * P * (a - mu), where * denotes the dot-product. Unfortunately
        # TensorFlow handles vector * P slightly suboptimal, hence we convert the vectors to
        # 1xd/dx1 matrices and finally flatten the resulting 1x1 matrix into a scalar. All
        # operations happen over the batch size, which is dimension 0.
        prod = K.batch_dot(K.expand_dims(a - mu, 1), P)
        prod = K.batch_dot(prod, K.expand_dims(a - mu, -1))
        A = -.5 * K.batch_flatten(prod)
        assert K.ndim(A) == 2
        return A
def complex_standardization(input_centred, Vrr, Vii, Vri, layernorm=False, axis=-1):
    """Complex standardization (whitening) of a centred complex input.

    Along ``axis``, the first half of ``input_centred`` holds the real parts and
    the second half holds the imaginary parts. Each complex feature is
    multiplied by W = V^(-1/2), the inverse square root of its 2x2 covariance
    matrix V = [[Vrr, Vri], [Vri, Vii]].

    Arguments:
        input_centred -- Mean-centred input tensor (real half, then imaginary
            half along ``axis``). Any rank is supported.
        Vrr -- Real/real component of the covariance matrix V.
        Vii -- Imaginary/imaginary component of the covariance matrix V.
        Vri -- Off-diagonal (real/imaginary) component of the covariance matrix V.

    Keyword Arguments:
        layernorm {bool} -- If True, the broadcast shape of the W terms keeps the
            batch dimension. This presumably means V was computed per sample —
            confirm with callers. (default: {False})
        axis {int} -- Axis holding the stacked real/imaginary features
            (default: {-1}).

    Returns:
        Tensor with the same shape as ``input_centred``, holding
        ``[x_real_normed, x_imag_normed]`` along ``axis``, where
            x_real_normed = Wrr * x_real_centred + Wri * x_imag_centred
            x_imag_normed = Wri * x_real_centred + Wii * x_imag_centred

    Math:
        With s = sqrt(det V) and t = sqrt(trace V + 2 s), the square root of V is
                  [ Vrr+s  Vri   ]
            (1/t) [ Vri    Vii+s ]
        (https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix).
        Its determinant is s, so the analytic 2x2 inverse
             [ A B ]             [  D  -B ]
        inv( [ C D ] ) = (1/det) [ -C   A ]
        (http://mathworld.wolfram.com/MatrixInverse.html) gives
                         [  Vii+s  -Vri   ]
            W = (1/(s t)) [ -Vri     Vrr+s ]
        This avoids forming sqrt(V) explicitly.
    """
    # e.g. [batch, height, width, channels] with channels == 2 * input_dim:
    # the first input_dim channels are real and the rest are imaginary. The
    # broadcast shape of the W terms depends only on the number of channels
    # (plus the batch size when layernorm is set).
    ndim = K.ndim(input_centred)
    input_dim = K.shape(input_centred)[axis] // 2

    variances_broadcast = [1] * ndim
    variances_broadcast[axis] = input_dim

    if layernorm:
        variances_broadcast[0] = K.shape(input_centred)[0]

    tau = Vrr + Vii                      # trace of V
    delta = (Vrr * Vii) - (Vri ** 2)     # determinant of V

    s = K.sqrt(delta)
    t = K.sqrt(tau + 2 * s)

    inverse_st = 1.0 / (s * t)
    Wrr = (Vii + s) * inverse_st
    Wii = (Vrr + s) * inverse_st
    Wri = -Vri * inverse_st

    broadcast_Wrr = K.reshape(Wrr, variances_broadcast)
    broadcast_Wri = K.reshape(Wri, variances_broadcast)
    broadcast_Wii = K.reshape(Wii, variances_broadcast)

    # [Wrr | Wii] multiplies [real | imag]; [Wri | Wri] multiplies [imag | real].
    cat_W_4_real = K.concatenate([broadcast_Wrr, broadcast_Wii], axis=axis)
    cat_W_4_imag = K.concatenate([broadcast_Wri, broadcast_Wri], axis=axis)

    # Split the real/imaginary halves along `axis` for any rank. Hard-coded
    # [:, :, :, ...] indexing only worked for 4-D inputs with axis=-1.
    real_index = [slice(None)] * ndim
    imag_index = [slice(None)] * ndim
    real_index[axis] = slice(None, input_dim)
    imag_index[axis] = slice(input_dim, None)
    centred_real = input_centred[tuple(real_index)]
    centred_imag = input_centred[tuple(imag_index)]

    rolled_input = K.concatenate([centred_imag, centred_real], axis=axis)

    # wrr real + wri imag, wri real + wii imag
    output = cat_W_4_real * input_centred + cat_W_4_imag * rolled_input
    return output
Esempio n. 4
0
def modal_dot(a, b, transpose_a=False, transpose_b=False):
    """
    Multiply ``a`` and ``b`` (rank 2 or 3, dense or sparse), dispatching on
    their ranks so that a missing "batch" axis is broadcast automatically.

    This is useful for batches of graphs in different data modes, where the
    adjacency matrices may or may not be sparse and may have a different rank
    from the node attributes. It also covers many adjacency matrices paired
    with a single graph signal.

    If the types and shapes are known in advance, calling TensorFlow's own
    matmul functions directly may be faster.

    Rank combinations:

        - `a` rank 2, `b` rank 2 -> `a @ b`
        - `a` rank 3, `b` rank 3 -> `[a[i] @ b[i] for i in range(len(a))]`
        - `a` rank 2, `b` rank 3 -> `[a @ b[i] for i in range(len(b))]`
        - `a` rank 3, `b` rank 2 -> `[a[i] @ b for i in range(len(a))]`

    :param a: Tensor or SparseTensor with rank 2 or 3;
    :param b: Tensor or SparseTensor with rank 2 or 3;
    :param transpose_a: transpose the innermost 2 dimensions of `a`;
    :param transpose_b: transpose the innermost 2 dimensions of `b`;
    :return: Tensor or SparseTensor with rank = max(rank(a), rank(b)).
    """
    rank_a = K.ndim(a)
    rank_b = K.ndim(b)
    assert rank_a in (2, 3), "Expected a of rank 2 or 3, got {}".format(rank_a)
    assert rank_b in (2, 3), "Expected b of rank 2 or 3, got {}".format(rank_b)

    # Swap only the two innermost axes; a leading batch axis stays in place.
    if transpose_a:
        a = ops.transpose(a, None if rank_a == 2 else (0, 2, 1))
    if transpose_b:
        b = ops.transpose(b, None if rank_b == 2 else (0, 2, 1))

    if rank_a == rank_b:
        return dot(a, b)  # ...ij,...jk->...ik
    if rank_a == 2:
        return mixed_mode_dot(a, b)  # ij,bjk->bik

    # Remaining case, bij,jk->bik (a has the batch axis).
    if not (K.is_sparse(a) or K.is_sparse(b)):
        # Dense operands: matmul broadcasts directly, no reshaping needed.
        return tf.matmul(a, b)

    # With a sparse operand, folding the batch into the rows and doing one
    # rank-2 product is faster than a rank-3 sparse matmul.
    shape_a = tf.shape(a)
    shape_b = tf.shape(b)
    rows = ops.reshape(a, (-1, shape_a[2]))
    return ops.reshape(dot(rows, b), (-1, shape_a[1], shape_b[1]))
Esempio n. 5
0
def ComplexBN(input_centred, Vrr, Vii, Vri, beta,
               gamma_rr, gamma_ri, gamma_ii, scale=True,
               center=True, layernorm=False, axis=-1):
    """Apply complex batch/layer normalization to a centred input.

    Along ``axis``, the first half of ``input_centred`` holds real parts and the
    second half holds imaginary parts.

    Args:
        input_centred: Mean-centred input tensor.
        Vrr, Vii, Vri: Covariance terms passed to ``complex_standardization``
            (only used when ``scale`` is True).
        beta: Shift parameter, reshaped to ``2 * input_dim`` entries along
            ``axis`` (only used when ``center`` is True).
        gamma_rr, gamma_ri, gamma_ii: Entries of the symmetric 2x2 scaling
            matrix, each reshaped to ``input_dim`` entries along ``axis``
            (only used when ``scale`` is True).
        scale: Whether to standardize and then apply the gamma scaling.
        center: Whether to add ``beta``.
        layernorm: Forwarded to ``complex_standardization``.
        axis: Feature axis holding the stacked real/imaginary halves.

    Returns:
        Normalized tensor with the same shape as ``input_centred``.

    Raises:
        ValueError: if ``scale`` is True and the (axis, ndim) combination is
            not one of the supported layouts below.
    """
    ndim = K.ndim(input_centred)
    input_dim = tf.shape(input_centred)[axis] // 2
    if scale:
        gamma_broadcast_shape = [1] * ndim
        gamma_broadcast_shape[axis] = input_dim
    if center:
        broadcast_beta_shape = [1] * ndim
        broadcast_beta_shape[axis] = input_dim * 2

    if scale:
        standardized_output = complex_standardization(
            input_centred, Vrr, Vii, Vri,
            layernorm,
            axis=axis
        )

        # Now we perform the scaling and shifting of the normalized x using
        # the scaling parameter
        #           [  gamma_rr gamma_ri  ]
        #   Gamma = [  gamma_ri gamma_ii  ]
        # and the shifting parameter
        #    Beta = [beta_real beta_imag].T
        # where:
        # x_real_BN = gamma_rr * x_real_normed + gamma_ri * x_imag_normed + beta_real
        # x_imag_BN = gamma_ri * x_real_normed + gamma_ii * x_imag_normed + beta_imag

        broadcast_gamma_rr = tf.reshape(gamma_rr, gamma_broadcast_shape)
        broadcast_gamma_ri = tf.reshape(gamma_ri, gamma_broadcast_shape)
        broadcast_gamma_ii = tf.reshape(gamma_ii, gamma_broadcast_shape)

        # [g_rr | g_ii] multiplies [real | imag]; [g_ri | g_ri] multiplies [imag | real].
        cat_gamma_4_real = tf.concat([broadcast_gamma_rr, broadcast_gamma_ii], axis=axis)
        cat_gamma_4_imag = tf.concat([broadcast_gamma_ri, broadcast_gamma_ri], axis=axis)
        if (axis == 1 and ndim != 3) or ndim == 2:
            centred_real = standardized_output[:, :input_dim]
            centred_imag = standardized_output[:, input_dim:]
        elif ndim == 3:
            centred_real = standardized_output[:, :, :input_dim]
            centred_imag = standardized_output[:, :, input_dim:]
        elif axis == -1 and ndim == 4:
            centred_real = standardized_output[:, :, :, :input_dim]
            centred_imag = standardized_output[:, :, :, input_dim:]
        elif axis == -1 and ndim == 5:
            centred_real = standardized_output[:, :, :, :, :input_dim]
            centred_imag = standardized_output[:, :, :, :, input_dim:]
        else:
            # This is a free function, so report the `axis` argument (there is
            # no `self` here).
            raise ValueError(
                'Incorrect Batchnorm combination of axis and dimensions. axis should be either 1 or -1. '
                'axis: ' + str(axis) + '; ndim: ' + str(ndim) + '.'
            )
        rolled_standardized_output = tf.concat([centred_imag, centred_real], axis=axis)
        if center:
            broadcast_beta = tf.reshape(beta, broadcast_beta_shape)
            return cat_gamma_4_real * standardized_output + cat_gamma_4_imag * rolled_standardized_output + broadcast_beta
        else:
            return cat_gamma_4_real * standardized_output + cat_gamma_4_imag * rolled_standardized_output
    else:
        if center:
            broadcast_beta = tf.reshape(beta, broadcast_beta_shape)
            return input_centred + broadcast_beta
        else:
            return input_centred
Esempio n. 6
0
    def compile_backprop(self, proj_type='l2'):
        """Build backend functions for the per-layer linear constraints.

        For every layer of ``self.model`` that has weights, the layer's affine
        map is composed with the constraints collected so far. This yields one
        gradient tensor (batch, units, input_size) and one bias tensor
        (batch, units) per layer. The final layer's constraints are then
        re-expressed relative to the class selected by the ``c`` placeholder
        (its row is subtracted from every row). The results are wrapped into
        ``K.function`` callables stored on ``self``.

        Args:
            proj_type: Projection type forwarded to ``proj_dist`` (default
                ``'l2'``); also stored as ``self.proj_type``.

        Returns:
            ``self``, with ``_grad_f``, ``_bias_f``, ``_dist_f``, ``_all_f`` and
            ``_all_except_first_f`` set and ``compiled`` set to ``True``.

        NOTE(review): the ``A`` inputs (one per weighted layer except the last,
        shaped like that layer's output) are multiplied element-wise into the
        running gradients/biases. They presumably carry 0/1 activation masks —
        confirm against the callers of the compiled functions.
        """
        self.proj_type = proj_type

        layers = [layer for layer in self.model.layers if layer.get_weights()]

        A = [Input(layer.output_shape[1:]) for layer in layers[:-1]]

        # Store the constraint gradients and biases for each layer.
        grads, biases = [], []

        for i, layer in enumerate(layers):
            W, b = layer.weights

            if isinstance(layer, Dense):
                if i > 0 and K.ndim(A[i - 1]) == 4 and (K.image_data_format()
                                                        == 'channels_first'):

                    # The `Flatten` layer doesn't respect the channels-first
                    # dimension ordering, so it mixes up our dimensions. We need
                    # to correct for that here.
                    _, ch, h, w = K.int_shape(A[i - 1])
                    _, n_out = K.int_shape(W)
                    W = K.reshape(
                        K.permute_dimensions(K.reshape(W, (h, w, ch, n_out)),
                                             (2, 0, 1, 3)),
                        (ch * h * w, n_out))

                if len(grads) == 0:
                    grad = K.transpose(W)
                    bias = b

                    # Expand to batch shape.
                    grad = grad[None] * K.ones_like(A[i])[:, :, None]
                    bias = bias[None] * K.ones_like(A[i])

                else:
                    # Flatten the previous layer's mask so it lines up with W's rows.
                    A_i = K.reshape(
                        A[i - 1], [-1, np.prod(K.int_shape(A[i - 1])[1:])])

                    grad = (K.transpose(W)[None] * A_i[:, None]) @ grads[-1]
                    bias = (biases[-1] * A_i) @ W + b[None]

                grads.append(grad)
                biases.append(bias)

            else:
                # Every non-Dense weighted layer is handled as a 2-D convolution.
                if K.image_data_format() == 'channels_first':
                    _, ch_in, h_in, w_in = layer.input_shape
                    _, ch_out, h_out, w_out = layer.output_shape
                else:
                    _, h_in, w_in, ch_in = layer.input_shape
                    _, h_out, w_out, ch_out = layer.output_shape

                # Flattened network-input size: the last axis of every gradient.
                # It must be defined for the first layer too, because the
                # reshape after this if/else uses it in both branches.
                n = np.prod(self.input_shape)

                if len(grads) == 0:
                    # Push every one-hot input image through the convolution
                    # (K.conv2d adds no bias) to get the linear map's matrix.
                    if K.image_data_format() == 'channels_first':
                        grad = K.conv2d(K.reshape(
                            K.eye(ch_in * h_in * w_in),
                            [ch_in * h_in * w_in, ch_in, h_in, w_in]),
                                        W,
                                        padding=layer.padding,
                                        strides=layer.strides)

                        bias = K.tile(b[:, None, None], [1, h_out, w_out])

                    else:
                        grad = K.conv2d(K.reshape(
                            K.eye(ch_in * h_in * w_in),
                            [ch_in * h_in * w_in, h_in, w_in, ch_in]),
                                        W,
                                        padding=layer.padding,
                                        strides=layer.strides)

                        bias = K.tile(b[None, None], [h_out, w_out, 1])

                    # Expand to batch shape.
                    grad = grad[None] * K.ones_like(A[i])[:, None]
                    bias = bias[None] * K.ones_like(A[i])

                else:
                    # `grad` and `bias` still hold the un-flattened tensors from
                    # the previous iteration, so this assumes the previous
                    # weighted layer was also convolutional.
                    if K.image_data_format() == 'channels_first':
                        grad = K.reshape(
                            K.conv2d(K.reshape(grad * A[i - 1][:, None],
                                               (-1, ch_in, h_in, w_in)),
                                     W,
                                     padding=layer.padding,
                                     strides=layer.strides),
                            (-1, n, ch_out, h_out, w_out))

                        bias = K.conv2d(bias * A[i - 1],
                                        W,
                                        padding=layer.padding,
                                        strides=layer.strides) + b[None, :,
                                                                   None, None]

                    else:
                        grad = K.reshape(
                            K.conv2d(K.reshape(grad * A[i - 1][:, None],
                                               (-1, h_in, w_in, ch_in)),
                                     W,
                                     padding=layer.padding,
                                     strides=layer.strides),
                            (-1, n, h_out, w_out, ch_out))

                        bias = K.conv2d(bias * A[i - 1],
                                        W,
                                        padding=layer.padding,
                                        strides=layer.strides) + b[None, None,
                                                                   None]

                # Store as (batch, units, n) / (batch, units).
                grads.append(
                    K.permute_dimensions(
                        K.reshape(grad, (-1, n, ch_out * h_out * w_out)),
                        (0, 2, 1)))
                biases.append(K.batch_flatten(bias))

        # Handle the softmax constraints.
        c = K.placeholder((1, ), dtype='int32')

        softmax_grads = grads[-1]
        softmax_biases = biases[-1]

        # Row `c` of the last layer's constraints, kept with a unit axis so it
        # broadcasts when subtracted from every row.
        c_grad = K.permute_dimensions(
            K.gather(K.permute_dimensions(softmax_grads, (1, 0, 2)), c),
            (1, 0, 2))

        c_bias = K.transpose(K.gather(K.transpose(softmax_biases), c))

        grads[-1] = softmax_grads - c_grad
        biases[-1] = softmax_biases - c_bias

        grads_no_first_layer = K.concatenate(grads[1:], axis=1)
        biases_no_first_layer = K.concatenate(biases[1:], axis=1)

        grads = K.concatenate(grads, axis=1)
        biases = K.concatenate(biases, axis=1)

        # Calculate distances.
        x = K.placeholder(self.input_shape)

        distances = proj_dist(proj_type, K.reshape(x, (1, -1)), grads, biases)

        distances_no_first_layer = proj_dist(proj_type, K.reshape(x, (1, -1)),
                                             grads_no_first_layer,
                                             biases_no_first_layer)

        self._grad_f = K.function(A + [c], [grads])
        self._bias_f = K.function(A + [c], [biases])
        self._dist_f = K.function(A + [c, x], [distances])
        self._all_f = K.function(A + [c, x], [
            distances, grads[:, -self.n_classes:], biases[:, -self.n_classes:]
        ])

        self._all_except_first_f = K.function(A + [c, x], [
            distances_no_first_layer, grads_no_first_layer[:,
                                                           -self.n_classes:],
            biases_no_first_layer[:, -self.n_classes:]
        ])

        self.compiled = True

        return self
Esempio n. 7
0
    def call(self, inputs, **kwargs):
        """Complex layer normalization of ``inputs``.

        Along ``self.axis``, the first half of ``inputs`` holds real parts and
        the second half holds imaginary parts. Statistics are reduced over
        every axis except the batch axis (0) and ``self.axis``, so they are
        per-sample. The mean is subtracted when ``self.center`` is set. When
        ``self.scale`` is set, the 2x2 covariance terms (Vrr, Vii, Vri) are
        estimated, with ``self.epsilon`` added to the variances. Everything is
        then handed to ``complex_normalization`` with ``layernorm=True``.

        Raises:
            ValueError: for an unsupported (axis, ndim) combination, or when
                both ``scale`` and ``center`` are False.
        """
        input_shape = K.shape(inputs)
        ndim = K.ndim(inputs)

        reduction_axes = list(range(ndim))
        del reduction_axes[self.axis]
        del reduction_axes[0]

        feature_size = input_shape[self.axis]
        input_dim = feature_size // 2

        mu = K.mean(inputs, axis=reduction_axes)
        broadcast_mu_shape = [1] * ndim
        broadcast_mu_shape[self.axis] = feature_size
        broadcast_mu_shape[0] = K.shape(inputs)[0]
        broadcast_mu = K.reshape(mu, broadcast_mu_shape)

        input_centred = inputs - broadcast_mu if self.center else inputs
        centred_squared = input_centred**2

        # Pick the tensor axis along which the real/imaginary halves are split.
        if (self.axis == 1 and ndim != 3) or ndim == 2:
            split_axis = 1
        elif ndim == 3:
            split_axis = 2
        elif self.axis == -1 and ndim in (4, 5):
            split_axis = ndim - 1
        else:
            raise ValueError(
                'Incorrect Layernorm combination of axis and dimensions. axis should be either 1 or -1. '
                'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.')

        leading = (slice(None),) * split_axis
        first_half = leading + (slice(None, input_dim),)
        second_half = leading + (slice(input_dim, None),)

        if self.scale:
            Vrr = K.mean(centred_squared[first_half],
                         axis=reduction_axes) + self.epsilon
            Vii = K.mean(centred_squared[second_half],
                         axis=reduction_axes) + self.epsilon
            # Vri contains the real and imaginary covariance for each feature map.
            Vri = K.mean(
                input_centred[first_half] * input_centred[second_half],
                axis=reduction_axes,
            )
        elif not self.center:
            raise ValueError(
                'Error. Both scale and center in batchnorm are set to False.')
        else:
            Vrr = Vii = Vri = None

        return complex_normalization(input_centred,
                                     Vrr,
                                     Vii,
                                     Vri,
                                     self.beta,
                                     self.gamma_rr,
                                     self.gamma_ri,
                                     self.gamma_ii,
                                     self.scale,
                                     self.center,
                                     layernorm=True,
                                     axis=self.axis)
Esempio n. 8
0
    def call(self, inputs, **kwargs):
        """Constructs the NMS graph.

        Args:
            inputs: List of
                ``[boxes, classification, other[0], other[1], ...]`` tensors.
                When ``boxes`` is rank 4, the inputs are treated as
                time-distributed: the first two axes are merged before
                filtering and split apart again afterwards.

        Returns:
            List ``[boxes, scores, labels, other[0], ...]`` produced by running
            ``filter_detections`` on each batch element with ``tf.map_fn``
            (restored to the leading two axes in the time-distributed case).
        """
        boxes = inputs[0]
        classification = inputs[1]
        other = inputs[2:]

        time_distributed = K.ndim(boxes) == 4

        if time_distributed:
            boxes_shape = K.shape(boxes)
            classification_shape = K.shape(classification)
            other_shape = [K.shape(o) for o in other]
            # Record the ranks of `other` before they are flattened. K.ndim() of
            # a shape tensor is always 1, so the ranks cannot be recovered from
            # `other_shape` later.
            other_ndim = [K.ndim(o) for o in other]

            # Merge the two leading axes into one.
            new_boxes_shape = [-1] + [boxes_shape[i] for i in range(2, K.ndim(boxes))]
            new_classification_shape = [-1] + \
                [classification_shape[i] for i in range(2, K.ndim(classification) - 1)] + \
                [classification.get_shape()[-1]]
            new_other_shape = [[-1] + [o_s[i] for i in range(2, o_nd)]
                               for o_s, o_nd in zip(other_shape, other_ndim)]

            boxes = K.reshape(boxes, new_boxes_shape)
            classification = K.reshape(classification, new_classification_shape)
            other = [K.reshape(o, o_s) for o, o_s in zip(other, new_other_shape)]

        # wrap nms with our parameters
        def _filter_detections(args):
            boxes = args[0]
            classification = args[1]
            other = args[2]

            return filter_detections(
                boxes,
                classification,
                other,
                nms=self.nms,
                class_specific_filter=self.class_specific_filter,
                score_threshold=self.score_threshold,
                max_detections=self.max_detections,
                nms_threshold=self.nms_threshold,
            )

        # call filter_detections on each batch
        outputs = tf.map_fn(
            _filter_detections,
            elems=[boxes, classification, other],
            dtype=[K.floatx(), K.floatx(), 'int32'] + [o.dtype for o in other],
            parallel_iterations=self.parallel_iterations
        )

        if time_distributed:
            filtered_boxes = outputs[0]
            filtered_scores = outputs[1]
            filtered_labels = outputs[2]
            filtered_other = outputs[3:]

            # Split the merged leading axis back into the original two axes.
            final_boxes_shape = [boxes_shape[0], boxes_shape[1], self.max_detections, 4]
            final_scores_shape = [
                classification_shape[0],
                classification_shape[1],
                self.max_detections
            ]
            final_labels_shape = [
                classification_shape[0],
                classification_shape[1],
                self.max_detections
            ]
            # Keep every trailing axis (3 .. rank-1) of each `other` tensor.
            final_others_shape = [[o_s[0], o_s[1], self.max_detections] +
                                  [o_s[i] for i in range(3, o_nd)]
                                  for o_s, o_nd in zip(other_shape, other_ndim)]

            filtered_boxes = K.reshape(filtered_boxes, final_boxes_shape)
            filtered_scores = K.reshape(filtered_scores, final_scores_shape)
            filtered_labels = K.reshape(filtered_labels, final_labels_shape)
            filtered_other = [K.reshape(o, o_s) for o, o_s in zip(filtered_other,
                                                                  final_others_shape)]

            outputs = [filtered_boxes, filtered_scores, filtered_labels] + filtered_other

        return outputs
Esempio n. 9
0
def filter_detections(boxes,
                      classification,
                      other=None,
                      class_specific_filter=True,
                      nms=True,
                      score_threshold=0.05,
                      max_detections=300,
                      nms_threshold=0.5):
    """Filter detections using the boxes and classification values.

    Args:
        boxes: Tensor of shape ``(num_boxes, 4)`` containing the
            boxes in ``(x1, y1, x2, y2)`` format.
        classification: Tensor of shape ``(num_boxes, num_classes)``
            containing the classification scores.
        other (list, optional): List of tensors of shape ``(num_boxes, ...)``
            to filter along with the boxes and classification scores.
            ``None`` (the default) means no extra tensors.
        class_specific_filter (bool): Whether to perform filtering per class,
            or take the best scoring class and filter those.
        nms (bool): Whether to enable non maximum suppression.
        score_threshold (float): Boxes whose score is not strictly greater
            than this value are discarded before NMS / top-k selection.
        max_detections (int): Maximum number of detections to keep.
        nms_threshold (float): Threshold for the IoU value to determine when a
            box should be suppressed.

    Returns:
        list: ``[boxes, scores, labels, other[0], other[1], ...]``.
        ``boxes`` is shaped ``(max_detections, 4)`` and contains the
        ``(x1, y1, x2, y2)`` of the non-suppressed boxes.
        ``scores`` is shaped ``(max_detections,)`` and contains the scores
        of the predicted class.
        ``labels`` is shaped ``(max_detections,)`` (int32) and contains the
        predicted label.
        ``other[i]`` is shaped ``(max_detections, ...)`` and contains the
        filtered ``other[i]`` data.
        In case there are less than ``max_detections`` detections,
        the tensors are padded with -1's.
    """
    # A ``[]`` default would be one list object shared by every call; create a
    # fresh one instead.
    if other is None:
        other = []

    def _filter_detections(scores, labels):
        """Return ``(num_kept, 2)`` indices of ``(box_index, label)`` pairs
        whose score passes the threshold (and survives NMS when enabled)."""
        # threshold based on score
        indices = tf.where(K.greater(scores, score_threshold))

        if nms:
            filtered_boxes = tf.gather_nd(boxes, indices)
            filtered_scores = K.gather(scores, indices)[:, 0]

            # perform NMS
            nms_indices = tf.image.non_max_suppression(
                filtered_boxes,
                filtered_scores,
                max_output_size=max_detections,
                iou_threshold=nms_threshold)

            # filter indices based on NMS
            indices = K.gather(indices, nms_indices)

        # pair each surviving box index with its label
        labels = tf.gather_nd(labels, indices)
        indices = K.stack([indices[:, 0], labels], axis=1)

        return indices

    if class_specific_filter:
        all_indices = []
        # perform per class filtering
        for c in range(K.int_shape(classification)[1]):
            scores = classification[:, c]
            labels = c * tf.ones((K.shape(scores)[0],), dtype='int64')
            all_indices.append(_filter_detections(scores, labels))

        # concatenate indices to single tensor
        indices = K.concatenate(all_indices, axis=0)
    else:
        scores = K.max(classification, axis=1)
        labels = K.argmax(classification, axis=1)
        indices = _filter_detections(scores, labels)

    # select top k
    scores = tf.gather_nd(classification, indices)
    labels = indices[:, 1]
    scores, top_indices = tf.nn.top_k(
        scores, k=K.minimum(max_detections, K.shape(scores)[0]))

    # filter input using the final set of indices
    indices = K.gather(indices[:, 0], top_indices)
    boxes = K.gather(boxes, indices)
    labels = K.gather(labels, top_indices)
    other_ = [K.gather(o, indices) for o in other]

    # pad the outputs with -1 up to max_detections along the first axis
    pad_size = K.maximum(0, max_detections - K.shape(scores)[0])
    boxes = tf.pad(boxes, [[0, pad_size], [0, 0]], constant_values=-1)
    scores = tf.pad(scores, [[0, pad_size]], constant_values=-1)
    labels = tf.pad(labels, [[0, pad_size]], constant_values=-1)
    labels = K.cast(labels, 'int32')
    pads = lambda x: [[0, pad_size]] + [[0, 0] for _ in range(1, K.ndim(x))]
    other_ = [tf.pad(o, pads(o), constant_values=-1) for o in other_]

    # set shapes, since we know what they are
    boxes.set_shape([max_detections, 4])
    scores.set_shape([max_detections])
    labels.set_shape([max_detections])
    for o, s in zip(other_, [list(K.int_shape(o)) for o in other]):
        o.set_shape([max_detections] + s[1:])

    return [boxes, scores, labels] + other_
Esempio n. 10
0
 def call(self, inputs):
     """Modulate ``z`` feature-wise: ``z * (gamma + 1) + beta``.

     ``beta`` and ``gamma`` each receive ``K.ndim(z) - 2`` singleton axes
     inserted at position 1 so they broadcast against ``z`` (assumes they
     arrive as ``(batch, channels)`` — TODO confirm).
     """
     z, beta, gamma = inputs
     extra_axes = K.ndim(z) - 2
     for _ in range(extra_axes):
         beta = K.expand_dims(beta, 1)
         gamma = K.expand_dims(gamma, 1)
     scale = gamma + 1
     return z * scale + beta
Esempio n. 11
0
 def content_loss(img1, img2):
     """Per-sample sum of squared differences, divided by ``2 * H * W * C``.

     Both inputs must have identical static shapes and rank 4; the last three
     dimensions are read as height, width and channels.
     """
     shape1, shape2 = K.int_shape(img1), K.int_shape(img2)
     assert shape1 == shape2
     assert K.ndim(img1) == 4
     _, height, width, channels = shape1
     per_sample = K.sum((img1 - img2)**2, axis=(1, 2, 3))
     return per_sample / height / width / channels / 2
Esempio n. 12
0
        def _mask_batch(y_true,
                        y_pred,
                        iou_threshold=0.5,
                        mask_size=(28, 28),
                        parallel_iterations=32):
            """Average ``compute_mask_loss`` over every sample in the batch.

            Layout read from the last axis:
              * ``y_pred``: ``[:4]`` boxes, ``[4:]`` flattened predicted masks.
              * ``y_true``: ``[:5]`` annotations, ``[5]`` / ``[6]`` target mask
                width / height (taken from element ``[0, 0]`` only),
                ``[7:]`` flattened target masks.
            Rank-4 inputs have their two leading axes merged first.
            """
            if K.ndim(y_pred) == 4:
                pred_dims = tf.shape(y_pred)
                y_pred = tf.reshape(y_pred, [
                    pred_dims[0] * pred_dims[1], pred_dims[2], pred_dims[3]
                ])

                true_dims = tf.shape(y_true)
                y_true = tf.reshape(y_true, [
                    true_dims[0] * true_dims[1], true_dims[2], true_dims[3]
                ])

            # predicted blobs
            pred_boxes = y_pred[:, :, :4]
            pred_masks = y_pred[:, :, 4:]

            # ground-truth blobs
            gt_annotations = y_true[:, :, :5]
            width = K.cast(y_true[0, 0, 5], dtype='int32')
            height = K.cast(y_true[0, 0, 6], dtype='int32')
            gt_masks = y_true[:, :, 7:]

            # restore the spatial layout of both mask tensors
            gt_masks = K.reshape(gt_masks,
                                 (K.shape(gt_masks)[0], K.shape(gt_masks)[1],
                                  height, width))
            pred_masks = K.reshape(pred_masks,
                                   (K.shape(pred_masks)[0],
                                    K.shape(pred_masks)[1], mask_size[0],
                                    mask_size[1], -1))

            def _per_sample_loss(sample):
                """Mask loss for a single batch element."""
                s_boxes, s_masks, s_annotations, s_targets = sample
                return compute_mask_loss(
                    s_boxes,
                    s_masks,
                    s_annotations,
                    s_targets,
                    width,
                    height,
                    iou_threshold=iou_threshold,
                    mask_size=mask_size,
                )

            per_sample_losses = tf.map_fn(
                _per_sample_loss,
                elems=[pred_boxes, pred_masks, gt_annotations, gt_masks],
                dtype=K.floatx(),
                parallel_iterations=parallel_iterations)

            return K.mean(per_sample_losses)
Esempio n. 13
0
 def temp_norm(ten, axis=None):
     """Euclidean norm of ``ten`` along ``axis`` with ``K.epsilon()`` inside
     the square root.

     When ``axis`` is ``None`` the channel axis is used: 1 for
     ``'channels_first'`` data format, otherwise the last axis.
     """
     if axis is None:
         channels_first = K.image_data_format() == 'channels_first'
         axis = 1 if channels_first else K.ndim(ten) - 1
     sum_of_squares = K.sum(K.square(ten), axis=axis)
     return K.sqrt(K.epsilon() + sum_of_squares)
Esempio n. 14
0
    def call(self, x, mask=None):
        """Apply batch renormalization to ``x``.

        Modes:
          * 0 / 2: feature-wise normalization with batch statistics corrected
            by the clipped, gradient-stopped factors ``r`` and ``d``. Mode 0
            additionally switches to the running statistics outside the
            training phase via ``K.in_train_phase``.
          * 1: sample-wise normalization over ``self.axis``.

        Args:
            x: Input tensor.
            mask: Unused; kept for layer API compatibility.

        Returns:
            The normalized tensor.

        Raises:
            ValueError: If ``self.mode`` is not 0, 1 or 2.
        """
        if self.mode == 0 or self.mode == 2:
            assert self.built, 'Layer must be built before being called'
            input_shape = K.int_shape(x)

            reduction_axes = list(range(len(input_shape)))
            del reduction_axes[self.axis]
            broadcast_shape = [1] * len(input_shape)
            broadcast_shape[self.axis] = input_shape[self.axis]

            # True when reducing over every axis except the last, so the
            # per-feature statistics broadcast without a reshape. Compare
            # against a list: in Python 3 a list never equals a ``range``,
            # which previously made this fast path unreachable.
            normalizing_last_axis = (
                sorted(reduction_axes) == list(range(K.ndim(x)))[:-1])

            # Only the batch moments are needed; the normalized output is
            # recomputed below with the renormalization correction.
            _, mean_batch, var_batch = K.normalize_batch_in_training(
                x, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)

            std_batch = K.sqrt(var_batch + self.epsilon)

            # ``self.running_std`` holds a variance (it is updated with
            # ``std_batch ** 2`` below), hence the square roots here.
            r_max_value = K.get_value(self.r_max)
            r = std_batch / (K.sqrt(self.running_std + self.epsilon))
            r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

            d_max_value = K.get_value(self.d_max)
            d = (mean_batch - self.running_mean) / K.sqrt(self.running_std +
                                                          self.epsilon)
            d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

            if normalizing_last_axis:
                x_normed_batch = (x - mean_batch) / std_batch
                x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
            else:
                # need broadcasting
                broadcast_mean = K.reshape(mean_batch, broadcast_shape)
                broadcast_std = K.reshape(std_batch, broadcast_shape)
                broadcast_r = K.reshape(r, broadcast_shape)
                broadcast_d = K.reshape(d, broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

                x_normed_batch = (x - broadcast_mean) / broadcast_std
                x_normed = (x_normed_batch * broadcast_r +
                            broadcast_d) * broadcast_gamma + broadcast_beta

            # explicit update to moving mean and variance
            self.add_update([
                K.moving_average_update(self.running_mean, mean_batch,
                                        self.momentum),
                K.moving_average_update(self.running_std, std_batch**2,
                                        self.momentum)
            ], x)

            # update r_max and d_max
            # NOTE(review): the schedule is evaluated in Python (K.get_value /
            # np.exp) when call() runs, so the assigned values are constants
            # in the resulting graph — confirm this is intended.
            t_val = K.get_value(self.t)
            r_val = self.r_max_value / (
                1 + (self.r_max_value - 1) * np.exp(-t_val))
            d_val = self.d_max_value / (1 + (
                (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
            t_val += float(self.t_delta)

            self.add_update([
                K.update(self.r_max, r_val),
                K.update(self.d_max, d_val),
                K.update(self.t, t_val)
            ], x)

            if self.mode == 0:
                if normalizing_last_axis:
                    x_normed_running = K.batch_normalization(
                        x,
                        self.running_mean,
                        self.running_std,
                        self.beta,
                        self.gamma,
                        epsilon=self.epsilon)
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean,
                                                       broadcast_shape)
                    broadcast_running_std = K.reshape(self.running_std,
                                                      broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        x,
                        broadcast_running_mean,
                        broadcast_running_std,
                        broadcast_beta,
                        broadcast_gamma,
                        epsilon=self.epsilon)

                # pick the normalized form of x corresponding to the training phase
                # for batch renormalization, inference time remains same as batchnorm
                x_normed = K.in_train_phase(x_normed, x_normed_running)

        elif self.mode == 1:
            # sample-wise normalization
            m = K.mean(x, axis=self.axis, keepdims=True)
            std = K.sqrt(
                K.var(x, axis=self.axis, keepdims=True) + self.epsilon)
            x_normed_batch = (x - m) / (std + self.epsilon)

            # NOTE(review): unlike modes 0/2, ``self.running_std`` is used here
            # without a square root although those modes store a variance in
            # it — confirm which is intended.
            r_max_value = K.get_value(self.r_max)
            r = std / (self.running_std + self.epsilon)
            r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

            d_max_value = K.get_value(self.d_max)
            d = (m - self.running_mean) / (self.running_std + self.epsilon)
            d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

            x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta

            # update r_max and d_max
            t_val = K.get_value(self.t)
            r_val = self.r_max_value / (
                1 + (self.r_max_value - 1) * np.exp(-t_val))
            d_val = self.d_max_value / (1 + (
                (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
            t_val += float(self.t_delta)

            self.add_update([
                K.update(self.r_max, r_val),
                K.update(self.d_max, d_val),
                K.update(self.t, t_val)
            ], x)

        else:
            # Previously fell through to an UnboundLocalError on x_normed.
            raise ValueError('Unsupported mode: {}'.format(self.mode))

        return x_normed
Esempio n. 15
0
 def to_vector(x):
     """Bring observations ``x`` to at most rank 2.

     Integer rank-1 inputs are one-hot encoded with depth
     ``self.env.observation_space.n`` (``self`` is a free variable taken from
     the enclosing scope); inputs of rank > 2 are flattened; anything else is
     returned unchanged.
     """
     if K.ndim(x) == 1 and K.dtype(x).startswith('int'):
         x = K.one_hot(x, self.env.observation_space.n)
     elif K.ndim(x) > 2:
         # Previously tested an unrelated name ``S`` here instead of ``x``.
         x = keras.layers.Flatten()(x)
     return x
Esempio n. 16
0
    def call(self, inputs, training=None):
        """Apply batch renormalization to ``inputs``.

        In the training phase the batch statistics are used, corrected by the
        clipped, gradient-stopped factors ``r`` and ``d``. At inference the
        running mean/variance are used, as in plain batch normalization.

        Args:
            inputs: Input tensor.
            training: ``None`` to follow the Keras learning phase. A static
                ``0``/``False`` selects inference directly; any other value is
                passed to ``K.in_train_phase``.

        Returns:
            The normalized tensor.
        """
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(inputs)

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # True when reducing over every axis except the last, so per-feature
        # statistics broadcast without a reshape. Compare against a list: in
        # Python 3 a list never equals a ``range``, which previously made this
        # fast path unreachable.
        normalizing_last_axis = (
            sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1])

        mean_batch, var_batch = K.moments(inputs,
                                          reduction_axes,
                                          shift=None,
                                          keep_dims=False)
        std_batch = (K.sqrt(var_batch + self.epsilon))

        # renormalization correction factors: clipped and excluded from backprop
        r_max_value = K.get_value(self.r_max)
        r = std_batch / (K.sqrt(self.running_variance + self.epsilon))
        r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

        d_max_value = K.get_value(self.d_max)
        d = (mean_batch - self.running_mean) / K.sqrt(self.running_variance +
                                                      self.epsilon)
        d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

        if normalizing_last_axis:
            x_normed_batch = (inputs - mean_batch) / std_batch
            x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
        else:
            # need broadcasting
            broadcast_mean = K.reshape(mean_batch, broadcast_shape)
            broadcast_std = K.reshape(std_batch, broadcast_shape)
            broadcast_r = K.reshape(r, broadcast_shape)
            broadcast_d = K.reshape(d, broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

            x_normed_batch = (inputs - broadcast_mean) / broadcast_std
            x_normed = (x_normed_batch * broadcast_r +
                        broadcast_d) * broadcast_gamma + broadcast_beta

        # explicit update to moving mean and variance
        self.add_update([
            K.moving_average_update(self.running_mean, mean_batch,
                                    self.momentum),
            K.moving_average_update(self.running_variance, std_batch**2,
                                    self.momentum)
        ], inputs)

        # update r_max and d_max
        # NOTE(review): the schedule is evaluated in Python (K.get_value /
        # np.exp) when call() runs, so the assigned values are constants in the
        # resulting graph — confirm this is intended.
        t_val = K.get_value(self.t)
        r_val = self.r_max_value / (1 +
                                    (self.r_max_value - 1) * np.exp(-t_val))
        d_val = self.d_max_value / (1 + (
            (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
        t_val += float(self.t_delta)

        self.add_update([
            K.update(self.r_max, r_val),
            K.update(self.d_max, d_val),
            K.update(self.t, t_val)
        ], inputs)

        def normalize_inference():
            """Normalize ``inputs`` with the running mean and variance."""
            if normalizing_last_axis:
                return K.batch_normalization(
                    inputs,
                    self.running_mean,
                    self.running_variance,
                    self.beta,
                    self.gamma,
                    epsilon=self.epsilon)

            # need broadcasting
            broadcast_running_mean = K.reshape(self.running_mean,
                                               broadcast_shape)
            broadcast_running_var = K.reshape(self.running_variance,
                                              broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            return K.batch_normalization(
                inputs,
                broadcast_running_mean,
                broadcast_running_var,
                broadcast_beta,
                broadcast_gamma,
                epsilon=self.epsilon)

        # A static inference phase must use the running statistics; this
        # previously returned the batch-statistics (training) output instead.
        if training in {0, False}:
            return normalize_inference()

        # pick the normalized form of inputs corresponding to the training phase
        # for batch renormalization, inference time remains same as batchnorm
        return K.in_train_phase(x_normed,
                                normalize_inference,
                                training=training)
Esempio n. 17
0
    def call(self, x, **kwargs):
        """Run the cross-layer interaction stack over ``x``.

        For each entry of ``self.cross_layer_size``, the embedding slices of
        the original input are multiplied (outer product via ``tf.matmul``)
        with those of the previous hidden layer. The result is compressed with
        a 1-D convolution (optionally low-rank, biased and activated) and split
        into the next hidden state and a direct output. The direct outputs are
        concatenated, summed over the last axis and fed to ``self.exFM_out``
        (with an optional residual path through ``self.exFM_out0``).

        Raises:
            ValueError: If ``x`` is not rank 3.
        """
        if K.ndim(x) != 3:
            raise ValueError(
                f'Wrong dimensions of inputs, expected 3 but input {K.ndim(x)}.'
            )

        embed_dim = int(x.get_shape()[-1])
        slice_sizes = embed_dim * [1]
        base_slices = tf.split(x, slice_sizes, 2)
        last_idx = len(self.cross_layer_size) - 1

        prev_hidden = x
        direct_outputs = []
        for idx, layer_size in enumerate(self.cross_layer_size):
            num_pairs = self.field_nums[0] * self.field_nums[idx]

            hidden_slices = tf.split(prev_hidden, slice_sizes, 2)
            interactions = tf.matmul(base_slices,
                                     hidden_slices,
                                     transpose_b=True)
            interactions = tf.reshape(interactions,
                                      shape=[embed_dim, -1, num_pairs])
            interactions = tf.transpose(interactions, perm=[1, 0, 2])

            if not self.reduce_D:
                filters = self.f_[idx]
            else:
                low_rank = tf.matmul(self.f0_[idx], self.f__[idx])
                low_rank = tf.reshape(low_rank,
                                      shape=[1, layer_size, num_pairs])
                filters = tf.transpose(low_rank, perm=[0, 2, 1])

            layer_out = tf.nn.conv1d(interactions,
                                     filters=filters,
                                     stride=1,
                                     padding='VALID')
            if self.use_bias:
                layer_out = tf.nn.bias_add(layer_out, self.bias[idx])
            layer_out = self.activation_layers[idx](layer_out)
            layer_out = tf.transpose(layer_out, perm=[0, 2, 1])

            if self.direct:
                prev_hidden, direct_part = layer_out, layer_out
            elif idx != last_idx:
                prev_hidden, direct_part = tf.split(
                    layer_out, 2 * [layer_size // 2], 1)
            else:
                prev_hidden, direct_part = 0, layer_out

            direct_outputs.append(direct_part)

        pooled = tf.reduce_sum(tf.concat(direct_outputs, axis=1), -1)

        if not self.use_residual:
            return self.exFM_out(pooled)
        residual = self.exFM_out0(pooled)
        return self.exFM_out(tf.concat([residual, pooled], axis=1))
Esempio n. 18
0
def predict(predict_var,
            x_unlabeled,
            inputs,
            y_true,
            batch_sizes,
            x_labeled=None,
            y_labeled=None):
  """Evaluates predict_var, batchwise, over all points in x_unlabeled and x_labeled.

  Args:
    predict_var:        list of tensors to evaluate and return
    x_unlabeled:        unlabeled input data
    inputs:             dictionary containing input_types and input_placeholders
      as key, value pairs, respectively
    y_true:             true labels tensorflow placeholder
    batch_sizes:        dictionary containing input_types and batch_sizes as
      key, value pairs, respectively
    x_labeled:          labeled input data
    y_labeled:          labeled input labels

  Returns:
    If the per-batch results have a shape, their concatenation along axis 0
    (presumably one entry per point, i.e. len(x_unlabeled) + len(x_labeled)
    entries); if they are scalars, their sum.
  """
  x_unlabeled, x_labeled, y_labeled = check_inputs(x_unlabeled, x_labeled,
                                                   y_labeled, y_true)

  # combined data
  x = np.concatenate((x_unlabeled, x_labeled), 0)
  # get shape of y_true
  y_shape = y_true.get_shape()[1:K.ndim(y_true)].as_list()

  # calculate batches for predict loop
  unlabeled_batch_size = batch_sizes.get('Unlabeled', 0)
  labeled_batch_size = batch_sizes.get('Labeled', 0)
  if 'Labeled' in batch_sizes and 'Unlabeled' in batch_sizes:
    assert unlabeled_batch_size == labeled_batch_size
  batch_size = min(len(x), max(unlabeled_batch_size, labeled_batch_size))
  batches = make_batches(len(x), batch_size)

  # Truth-testing a numpy array with more than one element raises ValueError,
  # so test the number of labeled points explicitly.
  has_labeled = x_labeled is not None and len(x_labeled) > 0

  y_preds = []
  # predict over all points
  for batch_start, batch_end in batches:
    feed_dict = {K.learning_phase(): 0}

    # feed corresponding input for each input_type
    for input_type, input_placeholder in inputs.items():
      if input_type == 'Unlabeled':
        feed_dict[input_placeholder] = x[batch_start:batch_end]
      elif input_type == 'Labeled':
        if has_labeled:
          batch_ids = np.random.choice(
              len(x_labeled),
              size=min(batch_sizes[input_type], len(x_labeled)),
              replace=False)
          feed_dict[input_placeholder] = x_labeled[batch_ids]
          feed_dict[y_true] = y_labeled[batch_ids]
        else:
          # we have no labeled points, so feed an empty array
          feed_dict[input_placeholder] = x[0:0]
          feed_dict[y_true] = np.empty([0] + y_shape)

    # evaluate the batch
    y_pred_batch = np.asarray(K.get_session().run(
        predict_var, feed_dict=feed_dict))
    y_preds.append(y_pred_batch)

  if y_preds[0].shape:
    return np.concatenate(y_preds)
  else:
    return np.sum(y_preds)
Esempio n. 19
0
    def policy_loss_with_metrics(self, Adv, A):
        """

        This method constructs the policy loss as a scalar-valued Tensor,
        together with a dictionary of metrics (also scalars).

        This method may be overridden to construct a custom policy loss and/or
        to change the accompanying metrics.

        For both supported update strategies the loss is the negated objective
        ``policy_term + entropy_beta * entropy``, i.e. the entropy enters as a
        bonus that the optimizer is encouraged to increase.

        Parameters
        ----------
        Adv : 1d Tensor, shape: [batch_size]

            A batch of advantages. A 2d Tensor of shape [batch_size, 1] is
            squeezed to 1d.

        A : nd Tensor, shape: [batch_size, ...]

            A batch of actions taken under the behavior policy.

        Returns
        -------
        loss, metrics : (Tensor, dict of Tensors)

            The policy loss along with some metrics, which is a dict of type
            ``{name <str>: metric <Tensor>}``. The loss and each of the metrics
            (dict values) are scalar Tensors, i.e. Tensors with ``ndim=0``.

            The ``loss`` is passed to a keras Model using
            ``train_model.add_loss(loss)``. Similarly, each metric in the
            metric dict is passed to the model using
            ``train_model.add_metric(metric, name=name, aggregation='mean')``.

        Raises
        ------
        NotImplementedError
            For ``update_strategy == 'cross_entropy'``.
        ValueError
            For any other unknown ``update_strategy``.

        """
        # advantages are treated as constants w.r.t. the policy parameters
        Adv = K.stop_gradient(Adv)
        if K.ndim(Adv) == 2:
            Adv = K.squeeze(Adv, axis=1)
        check_tensor(Adv, ndim=1)

        if self.update_strategy == 'vanilla':

            log_pi = self.dist.log_proba(A)
            check_tensor(log_pi, same_as=Adv)

            entropy = K.mean(self.dist.entropy())

            # flip sign of the whole objective (policy term + entropy bonus)
            # to get the loss, as in the 'ppo' branch. Previously only the
            # policy term was negated, which turned the entropy bonus into an
            # entropy penalty.
            loss = -(K.mean(Adv * log_pi) + self.entropy_beta * entropy)

            # no metrics related to behavior_dist since its not used in loss
            metrics = {'policy/entropy': entropy}

        elif self.update_strategy == 'ppo':

            log_pi = self.dist.log_proba(A)
            log_pi_old = K.stop_gradient(self.target_dist.log_proba(A))
            check_tensor(log_pi, same_as=Adv)
            check_tensor(log_pi_old, same_as=Adv)

            # clipped importance ratio pi / pi_old
            eps = self.ppo_clip_eps
            ratio = K.exp(log_pi - log_pi_old)
            ratio_clip = K.clip(ratio, 1 - eps, 1 + eps)

            clip_objective = K.mean(K.minimum(Adv * ratio, Adv * ratio_clip))
            entropy = K.mean(self.dist.entropy())
            kl_div = K.mean(self.target_dist.kl_divergence(self.dist))

            # flip sign to get loss from objective
            loss = -(clip_objective + self.entropy_beta * entropy)
            metrics = {'policy/entropy': entropy, 'policy/kl_div': kl_div}

        elif self.update_strategy == 'cross_entropy':
            raise NotImplementedError('cross_entropy')

        else:
            raise ValueError("unknown update_strategy '{}'".format(
                self.update_strategy))

        # rename
        loss = tf.identity(loss, name='policy_loss')

        return loss, metrics
Esempio n. 20
0
def train_step(return_vars,
               updates,
               x_unlabeled,
               inputs,
               y_true,
               batch_sizes,
               x_labeled=None,
               y_labeled=None,
               batches_per_epoch=100):
  """Performs one training step.

   Evaluates the tensors in return_vars and updates batches_per_epoch times,
   then returns the values of the tensors in return_vars summed over those
   evaluations.

  Args:
    return_vars: list of tensors to evaluate and return
    updates: list of tensors to evaluate only
    x_unlabeled: unlabeled input data
    inputs: dictionary containing input_types and input_placeholders as key,
      value pairs, respectively
    y_true: true labels placeholder
    batch_sizes: dictionary containing input_types and batch_sizes as key, value
      pairs, respectively. A batch size larger than the number of available
      points of that type is clamped to that number.
    x_labeled: labeled input data
    y_labeled: labeled input labels
    batches_per_epoch: parameter updates per epoch*

  Returns:
    numpy array of length len(return_vars): the evaluated result of all
    tensors in return_vars, summed across the batches_per_epoch evaluations

  *note: the term epoch is used loosely here, it does not necessarily
         refer to one iteration over the entire dataset. instead, it
         is just batches_per_epoch parameter updates.
  """
  x_unlabeled, x_labeled, y_labeled = check_inputs(x_unlabeled, x_labeled,
                                                   y_labeled, y_true)

  # combine data
  x = np.concatenate((x_unlabeled, x_labeled), 0)

  # get shape of y_true
  y_shape = y_true.get_shape()[1:K.ndim(y_true)].as_list()

  # Truth-testing a numpy array with more than one element raises ValueError,
  # so test the number of points explicitly.
  has_labeled = x_labeled is not None and len(x_labeled) > 0
  has_unlabeled = x_unlabeled is not None and len(x_unlabeled) > 0

  # the update ops are fetched together with return_vars; only the values of
  # return_vars are kept
  all_vars = return_vars + updates
  num_returned = len(return_vars)

  return_vars_ = np.zeros(shape=(num_returned,))
  # train batches_per_epoch batches
  for _ in range(batches_per_epoch):
    feed_dict = {K.learning_phase(): 1}

    # feed corresponding input for each input_type
    for input_type, input_placeholder in inputs.items():
      if input_type == 'Labeled':
        if has_labeled:
          batch_ids = np.random.choice(
              len(x_labeled),
              size=min(batch_sizes[input_type], len(x_labeled)),
              replace=False)
          feed_dict[input_placeholder] = x_labeled[batch_ids]
          feed_dict[y_true] = y_labeled[batch_ids]
        else:
          # we have no labeled points, so feed an empty array
          feed_dict[input_placeholder] = x[0:0]
          feed_dict[y_true] = np.empty([0] + y_shape)
      elif input_type == 'Unlabeled':
        if has_unlabeled:
          # sampling without replacement needs size <= population
          batch_ids = np.random.choice(
              len(x_unlabeled),
              size=min(batch_sizes[input_type], len(x_unlabeled)),
              replace=False)
          feed_dict[input_placeholder] = x_unlabeled[batch_ids]
        else:
          # we have no unlabeled points, so feed an empty array
          feed_dict[input_placeholder] = x[0:0]

    return_vars_ += np.asarray(K.get_session().run(
        all_vars, feed_dict=feed_dict)[:num_returned])

  return return_vars_
Esempio n. 21
0
 def softmax_over_time(self, x):
     """Softmax of ``x`` along axis 1; requires rank > 2.

     The per-slice maximum is subtracted first for numerical stability.
     """
     assert (K.ndim(x) > 2), "x dims too small"
     shifted = x - K.max(x, axis=1, keepdims=True)
     exps = K.exp(shifted)
     return exps / K.sum(exps, axis=1, keepdims=True)
 def _to_vector(X, space):
     """Bring ``X`` to at most rank 2.

     Integer rank-1 inputs are one-hot encoded with depth ``space.n``;
     inputs of rank > 2 are flattened; anything else is returned unchanged.
     """
     ndim = K.ndim(X)
     if ndim == 1 and K.dtype(X).startswith('int'):
         return K.one_hot(X, space.n)
     if ndim > 2:
         return keras.layers.Flatten()(X)
     return X
Esempio n. 23
0
def loss_per_pixel(mask, y_true, y_pred):
    """Per-sample mean absolute error weighted by ``mask``.

    Intended as the pixel L1 loss outside the hole; ``mask`` selects the
    region. Reduces over axes 1-3 of a 4-D (B, H, W, C) input.
    """
    assert K.ndim(y_true) == 4, 'Input tensor should be 4D (B, H, W, C).'
    masked_abs_error = K.abs(mask * (y_pred - y_true))
    return K.mean(masked_abs_error, axis=[1, 2, 3])
Esempio n. 24
0
    def call(self, inputs):
        """Pool node features and adjacency with a learned soft assignment ``S``.

        Args:
            inputs: ``[X, A]`` or ``[X, A, I]``. ``X`` holds node features and
                ``A`` the adjacency matrix (dense or sparse, per the
                ``K.is_sparse`` checks below). ``I`` is an optional per-node
                index tensor (presumably the graph index of each node — TODO
                confirm); a rank-2 ``I`` is reduced to its first column.

        Returns:
            list: ``[X_pooled, A_pooled]``, followed by ``I_pooled`` when ``I``
            is given and by ``S`` when ``self.return_mask`` is set.

        Side effects:
            Sets ``self.reduce_loss`` and registers two losses through
            ``self.add_loss``: a link-prediction loss and an entropy loss.
        """
        if len(inputs) == 3:
            X, A, I = inputs
            if K.ndim(I) == 2:
                I = I[:, 0]
        else:
            X, A = inputs
            I = None

        # size of the last axis of A (number of nodes)
        N = K.shape(A)[-1]
        # Check if the layer is operating in mixed or batch mode
        mode = ops.autodetect_mode(A, X)
        # in mixed/batch mode both losses are averaged down to a scalar below
        self.reduce_loss = mode in (modes.MIXED, modes.BATCH)

        # Get normalized adjacency
        # (self-loops are added first; ``I_`` is an identity matrix, unrelated
        # to the index tensor ``I``)
        if K.is_sparse(A):
            I_ = tf.sparse.eye(N, dtype=A.dtype)
            A_ = tf.sparse.add(A, I_)
        else:
            I_ = tf.eye(N, dtype=A.dtype)
            A_ = A + I_
        fltr = ops.normalize_A(A_)

        # Node embeddings
        Z = K.dot(X, self.kernel_emb)
        Z = ops.filter_dot(fltr, Z)
        if self.activation is not None:
            Z = self.activation(Z)

        # Compute cluster assignment matrix
        S = K.dot(X, self.kernel_pool)
        S = ops.filter_dot(fltr, S)
        S = activations.softmax(S, axis=-1)  # softmax applied row-wise

        # Link prediction loss
        # (norm over the last two axes of A - S S^T, per the helper name)
        S_gram = ops.matmul_A_BT(S, S)
        # In mixed mode A is densified and given a leading axis of size 1.
        # This reassigned A is also the one pooled below.
        # NOTE(review): tf.sparse.to_dense expects a SparseTensor; a dense A
        # in mixed mode would fail here — confirm.
        if mode == modes.MIXED:
            A = tf.sparse.to_dense(A)[None, ...]
        if K.is_sparse(A):
            LP_loss = tf.sparse.add(
                A, -S_gram)  # A/tf.norm(A) - S_gram/tf.norm(S_gram)
        else:
            LP_loss = A - S_gram
        LP_loss = tf.norm(LP_loss, axis=(-1, -2))
        if self.reduce_loss:
            LP_loss = K.mean(LP_loss)
        self.add_loss(LP_loss)

        # Entropy loss
        # (entropy of each node's assignment row, averaged over the last
        # remaining axis)
        entr = tf.negative(
            tf.reduce_sum(tf.multiply(S, K.log(S + K.epsilon())), axis=-1))
        entr_loss = K.mean(entr, axis=-1)
        if self.reduce_loss:
            entr_loss = K.mean(entr_loss)
        self.add_loss(entr_loss)

        # Pooling
        # (per the helper names: X_pooled = S^T Z, A_pooled = S^T A S)
        X_pooled = ops.matmul_AT_B(S, Z)
        A_pooled = ops.matmul_AT_B_A(S, A)

        output = [X_pooled, A_pooled]

        if I is not None:
            # one entry per distinct index value, repeated self.k times
            # (presumably self.k pooled nodes per graph; segment_mean assumes
            # I is sorted — TODO confirm)
            I_mean = tf.math.segment_mean(I, I)
            I_pooled = ops.repeat(I_mean, tf.ones_like(I_mean) * self.k)
            output.append(I_pooled)

        if self.return_mask:
            output.append(S)

        return output
Esempio n. 25
0
def complex_standardization(input_centred, Vrr, Vii, Vri,
                            layernorm=False, axis=-1):
    """Whiten a centred complex input with the inverse square root of its covariance.

    Along ``axis``, the first half of ``input_centred`` is treated as the real
    part and the second half as the imaginary part. ``Vrr``, ``Vii`` and
    ``Vri`` are the entries of the per-feature 2x2 covariance matrix

        V = [[Vrr, Vri],
             [Vri, Vii]]

    and are reshaped so they broadcast along ``axis``.

    Args:
        input_centred: Mean-centred input, real parts followed by imaginary
            parts along ``axis``.
        Vrr: Real/real covariance entries.
        Vii: Imaginary/imaginary covariance entries.
        Vri: Real/imaginary (off-diagonal) covariance entries.
        layernorm: If True, the batch dimension is kept in the broadcast shape
            of the covariance entries instead of being broadcast over.
        axis: Axis holding the concatenated real/imaginary parts.

    Returns:
        Tensor shaped like ``input_centred`` holding ``W @ x`` per complex
        feature, where ``W`` is the inverse square root of ``V``.

    Raises:
        ValueError: If the combination of ``axis`` and input rank is not one
            of the layouts handled by the slicing below.
    """
    ndim = K.ndim(input_centred)
    input_dim = tf.shape(input_centred)[axis] // 2
    variances_broadcast = [1] * ndim
    variances_broadcast[axis] = input_dim
    if layernorm:
        # Keep the batch dimension in the broadcast shape of V.
        variances_broadcast[0] = tf.shape(input_centred)[0]

    # We need the inverse square root of V. The square root is formed first
    # because its determinant is also needed for the inversion.

    # tau = Vrr + Vii = trace(V); non-negative when V is positive semi-definite.
    tau = Vrr + Vii
    # delta = Vrr * Vii - Vri ** 2 = det(V); non-negative when V is PSD.
    delta = (Vrr * Vii) - (Vri ** 2)

    s = tf.sqrt(delta)  # det(sqrt(V)) = sqrt(det(V))
    t = tf.sqrt(tau + 2 * s)
    # The square root matrix is
    #                 [ Vrr+s  Vri   ]
    #   sqrt(V) = (1/t) [ Vri    Vii+s ]
    # https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
    # Its determinant is s, so the analytic 2x2 inverse
    #        [ A B ]             [  D  -B ]
    #   inv( [ C D ] ) = (1/det) [ -C   A ]
    # http://mathworld.wolfram.com/MatrixInverse.html
    # gives
    #                         [  Vii+s  -Vri   ]
    #   W = inv(sqrt(V)) = (1/(s*t)) [ -Vri    Vrr+s ]
    # NOTE(review): a singular V (s == 0) produces inf/nan here; callers
    # presumably add an epsilon to Vrr/Vii beforehand -- confirm.
    inverse_st = 1.0 / (s * t)
    Wrr = (Vii + s) * inverse_st
    Wii = (Vrr + s) * inverse_st
    Wri = -Vri * inverse_st

    # W is now the inverse square root of V. Normalize with x_normalized = W.x:
    # x_real_normed = Wrr * x_real_centred + Wri * x_imag_centred
    # x_imag_normed = Wri * x_real_centred + Wii * x_imag_centred

    broadcast_Wrr = tf.reshape(Wrr, variances_broadcast)
    broadcast_Wri = tf.reshape(Wri, variances_broadcast)
    broadcast_Wii = tf.reshape(Wii, variances_broadcast)

    # Multiplier for [x_real | x_imag] and for the swapped [x_imag | x_real].
    cat_W_4_real = tf.concat([broadcast_Wrr, broadcast_Wii], axis=axis)
    cat_W_4_imag = tf.concat([broadcast_Wri, broadcast_Wri], axis=axis)

    if (axis == 1 and ndim != 3) or ndim == 2:
        centred_real = input_centred[:, :input_dim]
        centred_imag = input_centred[:, input_dim:]
    elif ndim == 3:
        centred_real = input_centred[:, :, :input_dim]
        centred_imag = input_centred[:, :, input_dim:]
    elif axis == -1 and ndim == 4:
        centred_real = input_centred[:, :, :, :input_dim]
        centred_imag = input_centred[:, :, :, input_dim:]
    elif axis == -1 and ndim == 5:
        centred_real = input_centred[:, :, :, :, :input_dim]
        centred_imag = input_centred[:, :, :, :, input_dim:]
    else:
        # This is a free function, so report the ``axis`` argument
        # (``self.axis`` would raise NameError here).
        raise ValueError(
            'Incorrect Batchnorm combination of axis and dimensions. axis should be either 1 or -1. '
            'axis: ' + str(axis) + '; ndim: ' + str(ndim) + '.'
        )
    rolled_input = tf.concat([centred_imag, centred_real], axis=axis)

    output = cat_W_4_real * input_centred + cat_W_4_imag * rolled_input

    #   Wrr * x_real_centered | Wii * x_imag_centered
    # + Wri * x_imag_centered | Wri * x_real_centered
    # -----------------------------------------------
    # = output

    return output
Esempio n. 26
0
 def body(self, S):
     """Flatten ``S`` when its rank exceeds 2, then apply the optional interaction layer.

     When ``S`` already has rank <= 2 and ``self.interaction_layer`` is None,
     ``S`` is returned unchanged.
     """
     flat = keras.layers.Flatten(name='flatten')(S) if K.ndim(S) > 2 else S
     interaction = self.interaction_layer
     return flat if interaction is None else interaction(flat)
Esempio n. 27
0
def repeat_(x, k):
    """Tile ``x`` ``k`` times along a new axis inserted at position 1.

    An input of shape ``(d0, d1, ...)`` becomes ``(d0, k, d1, ...)``; each
    slice along the new axis is a copy of ``x``.
    """
    rank = K.ndim(x)
    multiples = [1, k]
    multiples.extend([1] * (rank - 1))
    expanded = x[:, None, :]
    return K.tile(expanded, multiples)
Esempio n. 28
0
    def call(self, inputs, **kwargs):
        """Crop fixed-size RoI features from a feature pyramid.

        How ``inputs`` is read by the code below:
            inputs[0]: cast to floatx and used as the image shape; index 1
                normalizes y coordinates and index 2 normalizes x coordinates
                (presumably ``(batch, height, width, channels)`` -- confirm).
            inputs[1]: boxes, whose last axis is read as ``x1, y1, x2, y2``.
            inputs[2]: scores; passed through ``tf.map_fn`` but not used by
                the per-sample crop.
            inputs[3:]: feature-pyramid levels, one tensor per level.
        Gradients are stopped through boxes, scores and every pyramid level.

        If ``boxes`` has rank 4 (the "time distributed" case), the first two
        axes of boxes, scores and each pyramid level are merged before
        cropping and split again at the end.

        ``**kwargs`` is accepted but not used.

        Returns:
            Tensor of dtype ``K.floatx()`` built by ``tf.map_fn`` over the
            (merged) batch: per sample, crops of size ``self.crop_size`` in
            the original box order.
        """
        image_shape = K.cast(inputs[0], K.floatx())
        boxes = K.stop_gradient(inputs[1])
        scores = K.stop_gradient(inputs[2])
        fpn = [K.stop_gradient(i) for i in inputs[3:]]

        time_distributed = K.ndim(boxes) == 4

        if time_distributed:
            # Drop the leading entry of the image shape (presumably the
            # extra time axis -- confirm against the caller).
            image_shape = image_shape[1:]

            boxes_shape = tf.shape(boxes)
            scores_shape = tf.shape(scores)
            fpn_shape = [tf.shape(f) for f in fpn]

            # Merge axes 0 and 1 into one batch axis (-1), keep the rest.
            new_boxes_shape = [-1] + [
                boxes_shape[i] for i in range(2, K.ndim(boxes))
            ]
            new_scores_shape = [-1] + [
                scores_shape[i] for i in range(2, K.ndim(scores))
            ]
            new_fpn_shape = [[-1] + [f_s[i] for i in range(2, K.ndim(f))]
                             for f, f_s in zip(fpn, fpn_shape)]

            boxes = tf.reshape(boxes, new_boxes_shape)
            scores = tf.reshape(scores, new_scores_shape)
            fpn = [tf.reshape(f, f_s) for f, f_s in zip(fpn, new_fpn_shape)]

        def _roi_align(args):
            # ``args`` holds the per-sample slices of [boxes, scores, fpn]
            # produced by tf.map_fn; these names shadow the outer variables.
            boxes = args[0]
            scores = args[1]
            fpn = args[2]

            # compute from which level to get features from
            target_levels = self.map_to_level(boxes)

            # process each pyramid independently
            rois, ordered_indices = [], []
            for i in range(len(fpn)):
                # select the boxes and classification from this pyramid level
                # (positions of the boxes assigned to level i; these are kept
                # so the crops can be scattered back into box order later)
                indices = tf.where(K.equal(target_levels, i))
                ordered_indices.append(indices)

                level_boxes = tf.gather_nd(boxes, indices)
                fpn_shape = K.cast(K.shape(fpn[i]), dtype=K.floatx())

                # convert to expected format for crop_and_resize
                # ([y1, x1, y2, x2] normalized so that 1.0 maps to the last
                # feature-map row/column). Image-space coordinates are rescaled
                # to this level's feature size; the far edges additionally
                # subtract 1 before normalization (presumably treating x2/y2
                # as exclusive -- confirm).
                x1 = level_boxes[:, 0]
                y1 = level_boxes[:, 1]
                x2 = level_boxes[:, 2]
                y2 = level_boxes[:, 3]
                level_boxes = K.stack([
                    (y1 / image_shape[1] * fpn_shape[0]) / (fpn_shape[0] - 1),
                    (x1 / image_shape[2] * fpn_shape[1]) / (fpn_shape[1] - 1),
                    (y2 / image_shape[1] * fpn_shape[0] - 1) /
                    (fpn_shape[0] - 1),
                    (x2 / image_shape[2] * fpn_shape[1] - 1) /
                    (fpn_shape[1] - 1),
                ],
                                      axis=1)

                # append the rois to the list of rois
                # (the single feature map is expanded to a batch of one, so
                # every box refers to image index 0)
                rois.append(
                    tf.image.crop_and_resize(
                        K.expand_dims(fpn[i], axis=0), level_boxes,
                        tf.zeros((K.shape(level_boxes)[0], ), dtype='int32'),
                        self.crop_size))

            # concatenate rois to one blob
            rois = K.concatenate(rois, axis=0)

            # reorder rois back to original order
            # NOTE(review): assumes map_to_level assigns every box to a level in
            # range(len(fpn)); otherwise fewer rows than box positions exist and
            # some indices would fall outside the scatter shape.
            indices = K.concatenate(ordered_indices, axis=0)
            rois = tf.scatter_nd(indices, rois, K.cast(K.shape(rois), 'int64'))

            return rois

        roi_batch = tf.map_fn(_roi_align,
                              elems=[boxes, scores, fpn],
                              dtype=K.floatx(),
                              parallel_iterations=self.parallel_iterations)

        if time_distributed:
            # Split the merged batch axis back into the original first two axes.
            roi_shape = tf.shape(roi_batch)
            new_roi_shape = [boxes_shape[0], boxes_shape[1]] + \
                            [roi_shape[i] for i in range(1, K.ndim(roi_batch))]
            roi_batch = tf.reshape(roi_batch, new_roi_shape)

        return roi_batch
Esempio n. 29
0
def ComplexBN(input_centred,
              Vrr,
              Vii,
              Vri,
              beta,
              gamma_rr,
              gamma_ri,
              gamma_ii,
              scale=True,
              center=True,
              layernorm=False,
              axis=-1):
    """Complex batch normalization.

    Along ``axis``, ``input_centred`` holds the real parts in its first half
    and the imaginary parts in its second half. When ``scale`` is set, the
    input is whitened with ``complex_standardization``. It is then multiplied
    per complex feature by the symmetric matrix

        Gamma = [[gamma_rr, gamma_ri],
                 [gamma_ri, gamma_ii]]

    When ``center`` is set, ``beta`` (real half followed by imaginary half)
    is added at the end.

    Arguments:
        input_centred -- mean-centred input data
        Vrr -- real/real entry of the 2x2 covariance matrix V
        Vii -- imaginary/imaginary entry of the covariance matrix V
        Vri -- off-diagonal (real/imaginary) entry of the covariance matrix V
        beta -- learnable shift parameter
        gamma_rr -- rr entry of the 2x2 scaling matrix gamma
        gamma_ri -- ri (off-diagonal) entry of the 2x2 scaling matrix gamma
        gamma_ii -- ii entry of the 2x2 scaling matrix gamma

    Keyword Arguments:
        scale {bool} -- whiten with V and scale by gamma (default: {True})
        center {bool} -- add the shift beta (default: {True})
        layernorm {bool} -- forwarded to complex_standardization (default: {False})
        axis {int} -- axis holding the real/imaginary halves (default: {-1})

    Raises:
        ValueError: unsupported combination of axis and input rank
            (only checked when scale is True)

    Returns:
        The normalized (and optionally scaled and shifted) input.
    """
    rank = K.ndim(input_centred)
    half = K.shape(input_centred)[axis] // 2

    def _broadcast_shape(size):
        # All ones except ``size`` along ``axis``, so per-feature
        # parameters broadcast over every other dimension.
        shape = [1] * rank
        shape[axis] = size
        return shape

    if not scale:
        if center:
            return input_centred + K.reshape(beta, _broadcast_shape(half * 2))
        return input_centred

    normed = complex_standardization(input_centred,
                                     Vrr,
                                     Vii,
                                     Vri,
                                     layernorm,
                                     axis=axis)

    # x_real_BN = gamma_rr * x_real_normed + gamma_ri * x_imag_normed (+ beta_real)
    # x_imag_BN = gamma_ri * x_real_normed + gamma_ii * x_imag_normed (+ beta_imag)
    gamma_shape = _broadcast_shape(half)
    g_rr = K.reshape(gamma_rr, gamma_shape)
    g_ri = K.reshape(gamma_ri, gamma_shape)
    g_ii = K.reshape(gamma_ii, gamma_shape)
    # Multiplier for [x_real | x_imag] and for the swapped [x_imag | x_real].
    gamma_direct = K.concatenate([g_rr, g_ii], axis=axis)
    gamma_cross = K.concatenate([g_ri, g_ri], axis=axis)

    # Number of leading full-range axes in front of the axis that is split.
    if (axis == 1 and rank != 3) or rank == 2:
        n_lead = 1
    elif rank == 3:
        n_lead = 2
    elif axis == -1 and rank in (4, 5):
        n_lead = rank - 1
    else:
        raise ValueError(
            'Incorrect Batchnorm combination of axis and dimensions. axis'
            ' should be either 1 or -1. '
            'axis: ' + str(axis) + '; ndim: ' + str(rank) + '.')
    lead = (slice(None),) * n_lead
    real_half = normed[lead + (slice(None, half),)]
    imag_half = normed[lead + (slice(half, None),)]
    swapped = K.concatenate([imag_half, real_half], axis=axis)

    out = gamma_direct * normed + gamma_cross * swapped
    if center:
        out = out + K.reshape(beta, _broadcast_shape(half * 2))
    return out
Esempio n. 30
0
    def call(self,
             inputs,
             initial_state=None,
             initial_readout=None,
             ground_truth=None,
             mask=None,
             training=None):
        """Run the wrapped step model over the time axis of ``inputs``.

        Args:
            inputs: Input sequence tensor, or a list ``[inputs, ...]``. For a
                list, the next ``num_actual_states`` entries are the initial
                states. If ``self.readout`` is set, the following entry is the
                initial readout. If ``self.teacher_force`` is also set, the
                last entry is the ground truth.
            initial_state: Initial state tensor or list/tuple of tensors; used
                when ``inputs`` is not a list.
            initial_readout: Initial readout (only used with ``self.readout``).
                When missing or an optional-input placeholder, a zero tensor
                of shape ``(batch,) + model_output_shape[1:]`` is built.
            ground_truth: Teacher-forcing targets; required when
                ``self.teacher_force`` is set.
            mask: Mask tensor, or a list whose first element is used.
            training: Forwarded to ``K.in_train_phase`` when
                ``self.uses_learning_phase`` is set.

        Returns:
            ``outputs`` if ``self.return_sequences`` else the last output;
            when ``self.return_states`` is set, ``[y] + final_states``.

        Raises:
            Exception: If ``self.model`` is None, or teacher forcing is
                enabled but no ground truth is supplied.
            ValueError: If the number of initial states does not match the
                required count, or ``self.unroll`` is set while the time
                dimension is undefined.
        """
        # input shape: `(samples, time (padded with zeros), input_dim)`
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        if type(mask) is list:
            mask = mask[0]
        if self.model is None:
            raise Exception('Empty RecurrentModel.')
        num_req_states = self.num_states
        # With readout enabled, the last required state slot is the readout,
        # which is supplied separately from the recurrent states.
        if self.readout:
            num_actual_states = num_req_states - 1
        else:
            num_actual_states = num_req_states
        if type(inputs) is list:
            inputs_list = inputs[:]
            inputs = inputs_list.pop(0)
            initial_states = inputs_list[:num_actual_states]
            if len(initial_states) > 0:
                if self._is_optional_input_placeholder(initial_states[0]):
                    initial_states = self.get_initial_state(inputs)
            inputs_list = inputs_list[num_actual_states:]
            if self.readout:
                initial_readout = inputs_list.pop(0)
                if self.teacher_force:
                    ground_truth = inputs_list.pop()
        else:
            if initial_state is not None:
                if not isinstance(initial_state, (list, tuple)):
                    initial_states = [initial_state]
                else:
                    initial_states = list(initial_state)
                if self._is_optional_input_placeholder(initial_states[0]):
                    initial_states = self.get_initial_state(inputs)

            elif self.stateful:
                # Copy: the readout / teacher-forcing entries appended below
                # must not be written into the layer's own ``self.states``
                # list (which would make it grow on every call).
                initial_states = list(self.states)
            else:
                initial_states = self.get_initial_state(inputs)
        if self.readout:
            if initial_readout is None or self._is_optional_input_placeholder(
                    initial_readout):
                # Zero readout of shape (batch,) + model output shape[1:],
                # derived from ``inputs`` so the batch size is dynamic.
                output_shape = K.int_shape(_to_list((self.model.output))[0])
                output_ndim = len(output_shape)
                input_ndim = K.ndim(inputs)
                initial_readout = K.zeros_like(inputs)
                # Tuple index (not a list) so every backend reads it as
                # basic slicing rather than advanced indexing.
                slices = tuple([slice(None)] + [0] * (input_ndim - 1))
                initial_readout = initial_readout[slices]  # (batch_size,)
                initial_readout = K.reshape(initial_readout,
                                            (-1, ) + (1, ) * (output_ndim - 1))
                initial_readout = K.tile(initial_readout,
                                         (1, ) + tuple(output_shape[1:]))
            initial_states.append(initial_readout)
            if self.teacher_force:
                if ground_truth is None or self._is_optional_input_placeholder(
                        ground_truth):
                    raise Exception(
                        'ground_truth must be provided for RecurrentModel with teacher_force=True.'
                    )
                if K.backend() == 'tensorflow':
                    with tf.control_dependencies(None):
                        counter = K.zeros((1, ))
                else:
                    counter = K.zeros((1, ))
                counter = K.cast(counter, 'int32')
                # Resulting tail of the state list:
                # [..., counter, ground_truth, readout]
                initial_states.insert(-1, counter)
                initial_states.insert(-1, ground_truth)
                num_req_states += 2
        if len(initial_states) != num_req_states:
            raise ValueError('Layer requires ' + str(num_req_states) +
                             ' states but was passed ' +
                             str(len(initial_states)) + ' initial states.')
        input_shape = K.int_shape(inputs)
        if self.unroll and input_shape[1] is None:
            raise ValueError('Cannot unroll a RNN if the '
                             'time dimension is undefined. \n'
                             '- If using a Sequential model, '
                             'specify the time dimension by passing '
                             'an `input_shape` or `batch_input_shape` '
                             'argument to your first layer. If your '
                             'first layer is an Embedding, you can '
                             'also use the `input_length` argument.\n'
                             '- If using the functional API, specify '
                             'the time dimension by passing a `shape` '
                             'or `batch_shape` argument to your Input layer.')
        # NOTE(review): ``training=None`` is passed here rather than the
        # caller's ``training``; kept as-is, but confirm this is intended.
        preprocessed_input = self.preprocess_input(inputs, training=None)
        constants = self.get_constants(inputs, training=None)
        if self.decode:
            # Decoder mode: the input becomes the first state and the step
            # function is driven by a dummy zero sequence of output_length.
            initial_states.insert(0, inputs)
            preprocessed_input = K.zeros((1, self.output_length, 1))
            input_length = self.output_length
        else:
            input_length = input_shape[1]
        if self.uses_learning_phase:
            # Build the loop once per learning phase and select between the
            # two at run time. Only the train-phase ``updates`` are kept,
            # because the second assignment overwrites the first.
            with learning_phase_scope(0):
                last_output_test, outputs_test, states_test, updates = rnn(
                    self.step,
                    preprocessed_input,
                    initial_states,
                    go_backwards=self.go_backwards,
                    mask=mask,
                    constants=constants,
                    unroll=self.unroll,
                    input_length=input_length)
            with learning_phase_scope(1):
                last_output_train, outputs_train, states_train, updates = rnn(
                    self.step,
                    preprocessed_input,
                    initial_states,
                    go_backwards=self.go_backwards,
                    mask=mask,
                    constants=constants,
                    unroll=self.unroll,
                    input_length=input_length)

            last_output = K.in_train_phase(last_output_train,
                                           last_output_test,
                                           training=training)
            outputs = K.in_train_phase(outputs_train,
                                       outputs_test,
                                       training=training)
            states = []
            for state_train, state_test in zip(states_train, states_test):
                states.append(
                    K.in_train_phase(state_train,
                                     state_test,
                                     training=training))

        else:
            last_output, outputs, states, updates = rnn(
                self.step,
                preprocessed_input,
                initial_states,
                go_backwards=self.go_backwards,
                mask=mask,
                constants=constants,
                unroll=self.unroll,
                input_length=input_length)
        # Strip the bookkeeping entries added above: the decoder input at
        # the front, then readout, ground_truth and counter at the back.
        states = list(states)
        if self.decode:
            states.pop(0)
        if self.readout:
            states.pop()
            if self.teacher_force:
                states.pop()
                states.pop()
        if len(updates) > 0:
            self.add_update(updates)
        if self.stateful:
            updates = []
            for i in range(len(states)):
                updates.append((self.states[i], states[i]))
            self.add_update(updates, inputs)

        # Properly set learning phase
        if 0 < self.dropout + self.recurrent_dropout:
            last_output._uses_learning_phase = True
            outputs._uses_learning_phase = True

        if self.return_sequences:
            y = outputs
        else:
            y = last_output
        if self.return_states:
            return [y] + states
        else:
            return y