def _generate_dropout_mask(ones, rate, training=None, count=1):
    def dropped_inputs():
        return K.dropout(ones, rate)

    if count > 1:
        return [
            K.in_train_phase(dropped_inputs, ones, training=training)
            for _ in range(count)
        ]
    return K.in_train_phase(dropped_inputs, ones, training=training)
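A minimal usage sketch of the helper above (the wrapper name apply_gate_dropout and the gate count are illustrative, not from the original source, and it assumes _generate_dropout_mask and its K backend import are in scope); it builds one independent mask per recurrent gate and applies them to the inputs:

import tensorflow as tf

def apply_gate_dropout(inputs, rate, count=4, training=None):
    # One mask per gate (e.g. i, f, c, o in an LSTM), all shaped like `inputs`.
    ones = tf.ones_like(inputs)
    masks = _generate_dropout_mask(ones, rate, training=training, count=count)
    return [inputs * mask for mask in masks]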
Example #2
    def call(self, inputs, training=None):
        if training is None:
            training = K.learning_phase()
        if self.use_mc_dropout:
            training = True

        def drop_inputs():
            return K.dropout(inputs, self.unit_dropout)

        if 0. < self.unit_dropout < 1.:
            inputs = K.in_train_phase(drop_inputs, inputs, training=training)

        # kernel dropout
        ones = array_ops.ones_like(self.kernel)

        def dropped_weight_connections():
            return K.dropout(ones,
                             self.kernel_dropout) * (1 - self.kernel_dropout)

        if 0. < self.kernel_dropout < 1.:
            kern_dp_mask = K.in_train_phase(dropped_weight_connections,
                                            ones,
                                            training=training)
        else:
            kern_dp_mask = ones

        rank = len(inputs.shape)
        if rank > 2:
            # Broadcasting is required for the inputs.
            outputs = standard_ops.tensordot(inputs,
                                             self.kernel * kern_dp_mask,
                                             [[rank - 1], [0]])
            # Reshape the output back to the original ndim of the input.
            if not context.executing_eagerly():
                shape = inputs.shape.as_list()
                output_shape = shape[:-1] + [self.units]
                outputs.set_shape(output_shape)
        else:
            inputs = math_ops.cast(inputs, self._compute_dtype)
            if K.is_sparse(inputs):
                outputs = sparse_ops.sparse_tensor_dense_matmul(
                    inputs, self.kernel * kern_dp_mask)
            else:
                outputs = gen_math_ops.mat_mul(inputs,
                                               self.kernel * kern_dp_mask)
        if self.use_bias:
            outputs = nn.bias_add(outputs, self.bias)
        if self.activation is not None:
            return self.activation(outputs)  # pylint: disable=not-callable
        return outputs
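A small eager check (assuming TF 2.x) of the kernel-mask trick used above: K.dropout rescales surviving entries by 1/(1 - rate), so multiplying by (1 - rate) turns the result back into a plain 0/1 drop-connect mask.

import tensorflow as tf
from tensorflow.keras import backend as K

rate = 0.5
ones = tf.ones((3, 4))
mask = K.dropout(ones, rate) * (1 - rate)
print(mask.numpy())  # every entry is exactly 0.0 or 1.0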
Example #3
    def call(self, x, training=None):
        if len(x) != 2:
            raise Exception('input layers must be a list: mean and logvar')
        if len(x[0].shape) != 2 or len(x[1].shape) != 2:
            raise Exception(
                'input shape is not a vector [batchSize, latentSize]')
        mean = x[0]
        logvar = x[1]
        if mean.shape[0].value is None or logvar.shape[0].value is None:
            return mean + 0 * logvar

        if self.reg is not None:
            latent_loss = -0.5 * (1 + logvar - K.square(mean) - K.exp(logvar))
            latent_loss = K.sum(latent_loss, axis=-1)
            latent_loss = K.mean(latent_loss, axis=0)
            latent_loss = self.beta * latent_loss
            self.add_loss(latent_loss, x)

        def reparameterization_trick():
            epsilon = K.random_normal(shape=logvar.shape, mean=0., stddev=1.)
            stddev = K.exp(logvar * 0.5)
            return mean + stddev * epsilon

        return K.in_train_phase(reparameterization_trick,
                                mean + 0 * logvar,
                                training=training)
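The same reparameterization trick as a standalone function (a sketch with illustrative names, not taken from the example's repo); using K.shape(mean) instead of the static logvar.shape sidesteps the unknown-batch-size guard above:

import tensorflow as tf
from tensorflow.keras import backend as K

def sample_latent(mean, logvar):
    # z = mean + stddev * epsilon, with epsilon ~ N(0, 1)
    epsilon = K.random_normal(shape=K.shape(mean), mean=0., stddev=1.)
    return mean + K.exp(0.5 * logvar) * epsilon

z = sample_latent(tf.zeros((2, 8)), tf.zeros((2, 8)))  # shape (2, 8)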
Example #4
  def call(self, inputs, training=None):

    def noised():
      return inputs + K.random_normal(
          shape=array_ops.shape(inputs), mean=0., stddev=self.stddev)

    return K.in_train_phase(noised, inputs, training=training)
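For comparison, the built-in layer exposes the same behaviour; a short usage sketch (assuming TF 2.x eager execution):

import numpy as np
import tensorflow as tf

noise = tf.keras.layers.GaussianNoise(stddev=0.1)
x = np.zeros((4, 3), dtype="float32")
print(noise(x, training=True))   # zero-mean Gaussian noise added
print(noise(x, training=False))  # returned unchanged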
Example #5
  def call(self, inputs, training=None):
    if 0. < self.rate < 1.:
      noise_shape = self._get_noise_shape(inputs)

      def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed):  # pylint: disable=missing-docstring
        alpha = 1.6732632423543772848170429916717
        scale = 1.0507009873554804934193349852946
        alpha_p = -alpha * scale

        kept_idx = math_ops.greater_equal(
            K.random_uniform(noise_shape, seed=seed), rate)
        kept_idx = math_ops.cast(kept_idx, K.floatx())

        # Get affine transformation params
        a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
        b = -a * alpha_p * rate

        # Apply mask
        x = inputs * kept_idx + alpha_p * (1 - kept_idx)

        # Do affine transformation
        return a * x + b

      return K.in_train_phase(dropped_inputs, inputs, training=training)
    return inputs
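A quick sanity check (assuming TF 2.x and its built-in AlphaDropout layer) that the affine correction a * x + b above roughly preserves zero mean and unit variance for standard-normal inputs:

import tensorflow as tf

x = tf.random.normal((100000,))
layer = tf.keras.layers.AlphaDropout(rate=0.1)
y = layer(x, training=True)
print(float(tf.reduce_mean(y)), float(tf.math.reduce_std(y)))  # ~0.0, ~1.0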
Example #6
    def call(self, x, training=None):
        if len(x) != 2:
            raise Exception('input layers must be a list: mean and logvar')
        if len(x[0].shape) != 2 or len(x[1].shape) != 2:
            raise Exception(
                'input shape is not a vector [batchSize, latentSize]')

        mean = x[0]
        logvar = x[1]

        # trick to allow setting batch at train/eval time
        if mean.shape[0] is None or logvar.shape[0] is None:
            return mean + 0 * logvar  # Keras needs the *0 so the gradient is not None

        # kl divergence:
        latent_loss = -0.5 * (1 + logvar - K.square(mean) - K.exp(logvar))
        latent_loss = K.sum(latent_loss, axis=-1)  # sum over latent dimension
        latent_loss = K.mean(latent_loss, axis=0)  # avg over batch

        # use beta to force less usage of the latent vector space (beta = 1.0 here)
        latent_loss = 1.0 * latent_loss
        self.add_loss(latent_loss)

        #self.add_loss(latent_loss, x)

        def reparameterization_trick():
            epsilon = K.random_normal(shape=logvar.shape, mean=0., stddev=1.)
            stddev = K.exp(logvar * 0.5)
            return mean + stddev * epsilon

        return K.in_train_phase(
            reparameterization_trick, mean + 0 * logvar, training=training
        )  # TODO figure out why this is not working in the specified tf version???
Example #7
 def call(self, inputs, **kwargs):
     main_input, embedding_matrix = inputs
     input_shape_tensor = K.shape(main_input)
     last_input_dim = K.int_shape(main_input)[-1]
     emb_input_dim, emb_output_dim = K.int_shape(embedding_matrix)
     projected = K.dot(K.reshape(main_input, (-1, last_input_dim)),
                       self.embedding_weights['projection'])
     if self.add_biases:
         projected = K.bias_add(projected,
                                self.embedding_weights['biases'],
                                data_format='channels_last')
     if 0 < self.projection_dropout < 1:
         projected = K.in_train_phase(
             lambda: K.dropout(projected, self.projection_dropout),
             projected,
             training=kwargs.get('training'))
     attention = K.dot(projected, K.transpose(embedding_matrix))
     if self.scaled_attention:
         # scaled dot-product attention, described in
         # "Attention is all you need" (https://arxiv.org/abs/1706.03762)
         sqrt_d = K.constant(math.sqrt(emb_output_dim), dtype=K.floatx())
         attention = attention / sqrt_d
     result = K.reshape(
         self.activation(attention),
         (input_shape_tensor[0], input_shape_tensor[1], emb_input_dim))
     return result
Example #9
    def call(self, inputs, training=None):
        if 0. < self.rate < 1.:
            noise_shape = self._get_noise_shape(inputs)

            def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed):  # pylint: disable=missing-docstring
                alpha = 1.6732632423543772848170429916717
                scale = 1.0507009873554804934193349852946
                alpha_p = -alpha * scale

                kept_idx = math_ops.greater_equal(
                    K.random_uniform(noise_shape, seed=seed), rate)
                kept_idx = math_ops.cast(kept_idx, inputs.dtype)

                # Get affine transformation params
                a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
                b = -a * alpha_p * rate

                # Apply mask
                x = inputs * kept_idx + alpha_p * (1 - kept_idx)

                # Do affine transformation
                return a * x + b

            return K.in_train_phase(dropped_inputs, inputs, training=training)
        return inputs
Example #10
 def update_boost_strength(self):
     """
     Update boost strength using given strength factor during training.
     """
     factor = K.in_train_phase(self.boost_strength_factor, 1.0)
     self.add_update(
         self.boost_strength.assign(self.boost_strength * factor,
                                    read_value=False))
Example #11
    def call(self, inputs, training=None):
        stddev = K.sqrt(K.mean(K.square(inputs))) * 0.05

        def noised():
            return inputs + K.random_normal(
                shape=array_ops.shape(inputs), mean=0., stddev=stddev)

        return K.in_train_phase(noised, noised, training=training)
Example #12
    def apply_dropout_if_needed(self, attention_softmax, training=None):
        if 0.0 < self.dropout < 1.0:
            def dropped_softmax():
                return K.dropout(attention_softmax, self.dropout)

            return K.in_train_phase(dropped_softmax, attention_softmax,
                                    training=training)
        return attention_softmax
Example #13
 def call(self, x, mask=None):
     if 0. < self.rate < 1.:
         noise_shape = self._get_noise_shape(x)
         if self.permanent:
             x = K.dropout(x, self.rate)
         else:       
             x = K.in_train_phase(K.dropout(x, self.rate), x)
     return x
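A permanent flag like the one above is what Monte Carlo dropout needs: dropout stays active at prediction time. A minimal self-contained sketch of that use case with the stock Dropout layer (illustrative model, assuming TF 2.x):

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),
])
x_batch = np.random.rand(4, 8).astype("float32")

# training=True keeps dropout active at inference, like permanent=True above
preds = np.stack([model(x_batch, training=True).numpy() for _ in range(20)])
print(preds.mean(axis=0).ravel())  # MC estimate
print(preds.std(axis=0).ravel())   # per-sample uncertainty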
Example #14
    def call(self, inputs, training=None):
        def noised():
            return inputs + tf.random.normal(array_ops.shape(inputs),
                                             mean=0.0,
                                             stddev=self.stddev,
                                             dtype=inputs.dtype,
                                             seed=None)

        return K.in_train_phase(noised, inputs, training=training)
Example #15
  def call(self, inputs, training=None):

    def noised():
      return inputs + backend.random_normal(
          shape=array_ops.shape(inputs),
          mean=0.,
          stddev=self.stddev,
          dtype=inputs.dtype)

    return backend.in_train_phase(noised, inputs, training=training)
Example #16
    def call(self, inputs, training=None):
        if 0 < self.rate < 1:

            def noised():
                stddev = np.sqrt(self.rate / (1.0 - self.rate))
                return inputs * K.random_normal(
                    shape=array_ops.shape(inputs), mean=1.0, stddev=stddev)

            return K.in_train_phase(noised, inputs, training=training)
        return inputs
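A numeric check (TF 2.x eager) of the stddev choice above: multiplicative Gaussian noise with stddev sqrt(rate / (1 - rate)) has the same variance as inverted Bernoulli dropout at the same rate.

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

rate = 0.3
x = tf.ones((200000,))
gauss = x * K.random_normal(shape=tf.shape(x), mean=1.0,
                            stddev=np.sqrt(rate / (1.0 - rate)))
bern = K.dropout(x, rate)  # kept entries are rescaled by 1 / (1 - rate)
print(float(tf.math.reduce_std(gauss)), float(tf.math.reduce_std(bern)))  # both ~0.65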
Example #17
  def call(self, inputs, training=None):
    if 0 < self.rate < 1:

      def noised():
        stddev = np.sqrt(self.rate / (1.0 - self.rate))
        return inputs * K.random_normal(
            shape=array_ops.shape(inputs), mean=1.0, stddev=stddev)

      return K.in_train_phase(noised, inputs, training=training)
    return inputs
Example #18
    def build(self):
        # Embedding-layer to transform input into 3D-space.
        input_embedding = Embedding(self.data_sequence.vocab_size(), self.embedding_dim)

        # Inputs
        encoder_inputs = Input(shape=(None,))
        encoder_inputs_emb = input_embedding(encoder_inputs)

        # Encoder LSTM
        encoder = LSTM(self.lstm_hidden_dim, return_state=True)
        encoder_outputs, state_h, state_c = encoder(encoder_inputs_emb)
        state = [state_h, state_c]  # state will be used to initialize the decoder

        # Start vars (emulates a constant input)
        def constant(input_batch, size):
            batch_size = K.shape(input_batch)[0]
            return K.tile(K.ones((1, size)), (batch_size, 1))

        decoder_in = Lambda(constant, arguments={'size': self.embedding_dim})(encoder_inputs_emb)  # "start word"

        # Definition of further layers to be used in the model (decoder and mapping to vocab-sized vector)
        decoder_lstm = LSTM(self.lstm_hidden_dim, return_sequences=False, return_state=True)
        decoder_dense = Dense(self.data_sequence.vocab_size(), activation='softmax')

        chars = []  # Container for single results during the loop
        for i in range(self.max_decoder_length):
            # Reshape necessary to match LSTMs interface, cell state will be reintroduced in the next iteration
            decoder_in = Reshape((1, self.embedding_dim))(decoder_in)
            decoder_in, hidden_state, cell_state = decoder_lstm(decoder_in, initial_state=state)
            state = [hidden_state, cell_state]

            # Mapping
            decoder_out = decoder_dense(decoder_in)

            # Reshaping and storing for later concatenation
            char = Reshape((1, self.data_sequence.vocab_size()))(decoder_out)
            chars.append(char)

            # Teacher forcing. During training the original input will be used as input to the decoder
            decoder_in_train = Lambda(lambda x, ii: x[:, -ii], arguments={'ii': i+1})(encoder_inputs_emb)
            decoder_in = Lambda(lambda x, y: K.in_train_phase(y, x), arguments={'y': decoder_in_train})(decoder_in)

        # Single results are joined together (axis 1 vanishes)
        decoded_seq = Concatenate(axis=1)(chars)

        self.model = Model(encoder_inputs, decoded_seq, name="enc_dec")
        self.model.compile(optimizer='adam', loss='categorical_crossentropy')
        self.model.summary()

        try:
            file_name = 'enc_dec_model'
            plot_model(self.model, to_file=f'{file_name}.png', show_shapes=True)
            print(f"Model built. Saved {file_name}.png\n")
        except (ImportError, FileNotFoundError):
            print(f"Skipping plotting of model due to missing dependencies.")
Example #19
    def call(self, inputs, mask=None):
        def sparse():
            # number of dimensions in input might be < |k|. account for that
            actual_k = tf.minimum(K.shape(inputs)[-1] - 1, self.k)
            # multiply all values greater than the k smallest with 1, the rest with 0
            kth_smallest = tf.sort(inputs)[...,
                                           K.shape(inputs)[-1] - 1 - actual_k]
            return inputs * K.cast(K.greater(inputs, kth_smallest[:, None]),
                                   K.floatx())

        return K.in_train_phase(sparse, inputs)
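A small eager illustration (TF 2.x, static shapes for simplicity) of the k-largest masking idea in the example above:

import tensorflow as tf
from tensorflow.keras import backend as K

inputs = tf.constant([[0.1, 0.7, 0.3, 0.9]])
k = 2
# value of the (k+1)-th largest entry per row
kth_smallest = tf.sort(inputs)[..., inputs.shape[-1] - 1 - k]
mask = K.cast(K.greater(inputs, kth_smallest[:, None]), K.floatx())
print((inputs * mask).numpy())  # [[0.  0.7 0.  0.9]]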
Example #20
    def call(self, inputs, training=None, **kwargs):
        inputs = super().call(inputs, **kwargs)
        k = K.in_test_phase(x=self.k_inference, alt=self.k, training=training)
        kwinners = compute_kwinners(
            x=inputs,
            k=k,
            duty_cycles=self.duty_cycles,
            boost_strength=self.boost_strength,
        )

        duty_cycles = K.in_train_phase(
            lambda: self.compute_duty_cycle(kwinners),
            self.duty_cycles,
            training=training,
        )
        self.add_update(self.duty_cycles.assign(duty_cycles, read_value=False))
        increment = K.in_train_phase(K.shape(inputs)[0], 0, training=training)
        self.add_update(
            self.learning_iterations.assign_add(increment, read_value=False))

        return kwinners
Example #21
    def _generate_dropout_mask(self, inputs, training=None):
        if 0 < self.dropout < 1:
            ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))

            def dropped_inputs():
                return K.dropout(ones, self.dropout)

            self._dropout_mask = [
                K.in_train_phase(dropped_inputs, ones, training=training)
                for _ in range(4)
            ]
        else:
            self._dropout_mask = None
Example #22
  def get_constants(self, inputs, training=None):
    constants = []
    if self.implementation == 0 and 0 < self.dropout < 1:
      ones = K.zeros_like(inputs)
      ones = K.sum(ones, axis=1)
      ones += 1

      def dropped_inputs():
        return K.dropout(ones, self.dropout)

      dp_mask = [
          K.in_train_phase(dropped_inputs, ones, training=training)
          for _ in range(4)
      ]
      constants.append(dp_mask)
    else:
      constants.append([K.cast_to_floatx(1.) for _ in range(4)])

    if 0 < self.recurrent_dropout < 1:
      depthwise_shape = list(self.depthwise_kernel_shape)
      pointwise_shape = list(self.pointwise_kernel_shape)
      ones = K.zeros_like(inputs)
      ones = K.sum(ones, axis=1)
      ones = self.input_conv(
          ones, K.zeros(depthwise_shape), K.zeros(pointwise_shape),
          padding=self.padding)
      ones += 1.

      def dropped_inputs():  # pylint: disable=function-redefined
        return K.dropout(ones, self.recurrent_dropout)

      rec_dp_mask = [
          K.in_train_phase(dropped_inputs, ones, training=training)
          for _ in range(4)
      ]
      constants.append(rec_dp_mask)
    else:
      constants.append([K.cast_to_floatx(1.) for _ in range(4)])
    return constants
Example #23
    def call(self, inputs, training=None):
        def drop_connect():
            keep_prob = 1.0 - self.drop_connect_rate

            # Compute drop_connect tensor
            batch_size = tf.shape(inputs)[0]
            random_tensor = keep_prob
            random_tensor += tf.random.uniform([batch_size, 1, 1, 1],
                                               dtype=inputs.dtype)
            binary_tensor = tf.floor(random_tensor)
            output = tf.divide(inputs, keep_prob) * binary_tensor
            return output

        return K.in_train_phase(drop_connect, inputs, training=training)
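An eager sanity check (TF 2.x) of the stochastic-depth mask built above: floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob and 0 otherwise.

import tensorflow as tf

keep_prob = 0.8
random_tensor = keep_prob + tf.random.uniform([10000, 1, 1, 1])
binary_tensor = tf.floor(random_tensor)
print(float(tf.reduce_mean(binary_tensor)))  # ~0.8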
Example #24
    def call(self, inputs, **kwargs):

        is_training = kwargs.get('training', False)

        if self.dropconnect_prob > 0.0:

            def dropconnected():
                return dropconnect(self.kernel, self.dropconnect_prob)

            # Apply dropconnect if in training
            # Fails when overwriting kernel, hence the "DC"
            self.kernelDC = K.in_train_phase(dropconnected,
                                             self.kernel,
                                             training=is_training)
        else:
            self.kernelDC = self.kernel

        # Apply kernel to inputs
        # Note: This part came from Dense()
        rank = len(inputs.shape)
        if rank > 2:
            # Broadcasting is required for the inputs.
            outputs = standard_ops.tensordot(inputs, self.kernelDC,
                                             [[rank - 1], [0]])
            # Reshape the output back to the original ndim of the input.
            if not context.executing_eagerly():
                shape = inputs.shape.as_list()
                output_shape = shape[:-1] + [self.units]
                outputs.set_shape(output_shape)
        else:
            inputs = math_ops.cast(inputs, self._compute_dtype)
            if K.is_sparse(inputs):
                outputs = sparse_ops.sparse_tensor_dense_matmul(
                    inputs, self.kernelDC)
            else:
                outputs = gen_math_ops.mat_mul(inputs, self.kernelDC)

        # Add bias
        if self.use_bias:
            outputs = nn.bias_add(outputs, self.bias)

        # Apply scaling factor
        if self.scale:
            outputs = self.scaler(outputs)

        # Apply activation function
        if self.activation is not None:
            outputs = self.activation(outputs)  # pylint: disable=not-callable

        return outputs
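The dropconnect helper called above is not shown in this example; a common minimal definition (an assumption, not necessarily the example's actual code) just applies inverted dropout to the weight matrix:

from tensorflow.keras import backend as K

def dropconnect(w, prob):
    # zero each weight with probability `prob`; survivors are scaled by 1/(1 - prob)
    return K.dropout(w, prob)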
Example #25
def _time_distributed_dense(x,
                            w,
                            b=None,
                            dropout=None,
                            input_dim=None,
                            output_dim=None,
                            timesteps=None,
                            training=None):
    """Apply `y . w + b` for every temporal slice y of x.

    # Arguments
        x: input tensor.
        w: weight matrix.
        b: optional bias vector.
        dropout: whether to apply dropout (same dropout mask
            for every temporal slice of the input).
        input_dim: integer; optional dimensionality of the input.
        output_dim: integer; optional dimensionality of the output.
        timesteps: integer; optional number of timesteps.
        training: training phase tensor or boolean.

    # Returns
        Output tensor.
    """
    if not input_dim:
        input_dim = K.shape(x)[2]
    if not timesteps:
        timesteps = K.shape(x)[1]
    if not output_dim:
        output_dim = K.int_shape(w)[1]

    if dropout is not None and 0. < dropout < 1.:
        # apply the same dropout pattern at every timestep
        ones = K.ones_like(K.reshape(x[:, 0, :], (-1, input_dim)))
        dropout_matrix = K.dropout(ones, dropout)
        expanded_dropout_matrix = K.repeat(dropout_matrix, timesteps)
        x = K.in_train_phase(x * expanded_dropout_matrix, x, training=training)

    # collapse time dimension and batch dimension together
    x = K.reshape(x, (-1, input_dim))
    x = K.dot(x, w)
    if b is not None:
        x = K.bias_add(x, b)
    # reshape to 3D tensor
    if K.backend() == 'tensorflow':
        x = K.reshape(x, K.stack([-1, timesteps, output_dim]))
        x.set_shape([None, None, output_dim])
    else:
        x = K.reshape(x, (-1, timesteps, output_dim))
    return x
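A minimal usage sketch of _time_distributed_dense (illustrative shapes, assuming TF 2.x eager execution with the TensorFlow backend and the K backend import in scope):

import tensorflow as tf

x = tf.random.normal((2, 5, 8))    # (batch, timesteps, input_dim)
w = tf.random.normal((8, 16))      # (input_dim, output_dim)
b = tf.zeros((16,))
y = _time_distributed_dense(x, w, b, dropout=0.2,
                            input_dim=8, output_dim=16, timesteps=5,
                            training=True)
print(y.shape)  # (2, 5, 16)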
Example #26
    def _generate_recurrent_dropout_mask(self, inputs, training=None):
        if 0 < self.recurrent_dropout < 1:
            ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
            ones = K.tile(ones, (1, self.units))

            def dropped_inputs():
                return K.dropout(ones, self.recurrent_dropout)

            self._recurrent_dropout_mask = [
                K.in_train_phase(dropped_inputs, ones, training=training)
                for _ in range(4)
            ]
        else:
            self._recurrent_dropout_mask = None
Example #27
 def dot_product_attention(self, x, mask=None, dropout=0.1, training=None):
     q, k, v = x
     logits = tf.matmul(q, k, transpose_b=True)  # [bs, 8, len, len]
     if self.bias:
         logits += self.b
     if mask is not None:  # [bs, len]
         mask = tf.expand_dims(mask, axis=1)
         mask = tf.expand_dims(mask, axis=1)  # [bs,1,1,len]
         logits = self.mask_logits(logits, mask)
     weights = tf.nn.softmax(logits, name="attention_weights")
     weights = K.in_train_phase(K.dropout(weights, dropout),
                                weights,
                                training=training)
     x = tf.matmul(weights, v)
     return x
Example #28
def center(inputs, moving_mean, w, h, c, instance_norm=False):
    if instance_norm:
        x_t = tf.transpose(inputs, (0, 3, 1, 2))
        x_flat = tf.reshape(x_t, (-1, c, w * h))
        # (bs, c, w*h)
        m = tf.reduce_mean(x_flat, axis=2, keepdims=True)
        # (bs, c, 1)
    else:
        x_t = tf.transpose(inputs, (3, 0, 1, 2))
        x_flat = tf.reshape(x_t, (c, -1))
        # (c, bs*w*h)
        m = tf.reduce_mean(x_flat, axis=1, keepdims=True)
        m = K.in_train_phase(m, moving_mean)
        # (c, 1)
    f = x_flat - m
    return m, f
Example #29
    def apply_dropout_if_needed(self, attention_softmax, training=None):
        """
        apply dropout after attention softmax if desired
        :param attention_softmax:
        :param training:
        :return:
        """
        if 0.0 < self.dropout < 1.0:

            def dropped_softmax():
                return K.dropout(attention_softmax, self.dropout)

            return K.in_train_phase(dropped_softmax,
                                    attention_softmax,
                                    training=training)
        return attention_softmax
Example #30
    def W_bar(self):
        # Spectrally Normalized Weight
        W_mat = K.permute_dimensions(
            self.kernel, (3, 2, 0, 1))  # (h, w, i, o) => (o, i, h, w)
        W_mat = K.reshape(W_mat, [K.shape(W_mat)[0], -1])  # (o, i * h * w)

        if self.Ip < 1:
            raise ValueError(
                "The number of power iterations should be a positive integer")

        _u = self.u
        _v = None

        for _ in range(self.Ip):
            _v = _l2normalize(K.dot(_u, W_mat))
            _u = _l2normalize(K.dot(_v, K.transpose(W_mat)))

        sigma = K.sum(K.dot(_u, W_mat) * _v)

        K.update(self.u, K.in_train_phase(_u, self.u))
        return self.kernel / sigma
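_l2normalize is referenced but not defined in this example; the usual minimal definition in spectral-normalization ports (an assumption, not necessarily this repo's exact code) is:

from tensorflow.keras import backend as K

def _l2normalize(x, eps=1e-12):
    # scale the vector to unit L2 norm, guarding against division by zero
    return x / (K.sqrt(K.sum(K.square(x))) + eps)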
Example #31
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        def noised():
            eps = K.random_uniform(shape=[1], maxval=self.alpha)
            return inputs + K.random_normal(
                shape=K.shape(inputs), mean=0., stddev=eps)

        get_noised = K.in_train_phase(noised, normed, training=training)

        retrieved = stddev * get_noised + mean
        return retrieved
Example #32
    def call(self, x, training=None):
        def outputs_inference():
            # Apply truncation trick according to cutoff.
            num_layers = K.int_shape(x)[1]

            if self.cutoff is not None:
                beta = Ke.where(
                    np.arange(num_layers)[np.newaxis, :,
                                          np.newaxis] < self.cutoff,
                    self.psi *
                    np.ones(shape=(1, num_layers, 1), dtype=np.float32),
                    np.ones(shape=(1, num_layers, 1), dtype=np.float32))  #?
            else:
                beta = np.ones(shape=(1, num_layers, 1), dtype=np.float32)

            return self.moving_mean + (x - self.moving_mean) * beta  #?

        # Update moving average.
        mean = K.mean(x[:, 0], axis=0)  #?
        x_moving_mean = K.moving_average_update(self.moving_mean, mean,
                                                self.momentum)  #? add_update?

        # Apply truncation trick according to cutoff.
        num_layers = K.int_shape(x)[1]

        if self.cutoff is not None:
            beta = Ke.where(
                np.arange(num_layers)[np.newaxis, :, np.newaxis] < self.cutoff,
                self.psi * np.ones(shape=(1, num_layers, 1), dtype=np.float32),
                np.ones(shape=(1, num_layers, 1), dtype=np.float32))  #?
        else:
            beta = np.ones(shape=(1, num_layers, 1), dtype=np.float32)

        outputs = x_moving_mean + (x - self.moving_mean) * beta  #?

        return K.in_train_phase(outputs, outputs_inference, training=training)
Example #33
    def call(self, inputs, training=None):
        class_labels = K.squeeze(inputs[1], axis=1)
        inputs = inputs[0]
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        param_broadcast = [1] * len(input_shape)
        param_broadcast[self.axis] = input_shape[self.axis]
        param_broadcast[0] = K.shape(inputs)[0]
        if self.scale:
            broadcast_gamma = K.reshape(K.gather(self.gamma, class_labels),
                                        param_broadcast)
        else:
            broadcast_gamma = None

        if self.center:
            broadcast_beta = K.reshape(K.gather(self.beta, class_labels),
                                       param_broadcast)
        else:
            broadcast_beta = None

        normed, mean, variance = K.normalize_batch_in_training(
            inputs,
            gamma=None,
            beta=None,
            reduction_axes=reduction_axes,
            epsilon=self.epsilon)

        if training in {0, False}:
            return normed
        else:
            self.add_update([
                K.moving_average_update(self.moving_mean, mean, self.momentum),
                K.moving_average_update(self.moving_variance, variance,
                                        self.momentum)
            ], inputs)

            def normalize_inference():
                if needs_broadcasting:
                    # In this case we must explictly broadcast all parameters.
                    broadcast_moving_mean = K.reshape(self.moving_mean,
                                                      broadcast_shape)
                    broadcast_moving_variance = K.reshape(
                        self.moving_variance, broadcast_shape)
                    return K.batch_normalization(inputs,
                                                 broadcast_moving_mean,
                                                 broadcast_moving_variance,
                                                 beta=None,
                                                 gamma=None,
                                                 epsilon=self.epsilon)
                else:
                    return K.batch_normalization(inputs,
                                                 self.moving_mean,
                                                 self.moving_variance,
                                                 beta=None,
                                                 gamma=None,
                                                 epsilon=self.epsilon)

        # Pick the normalized form corresponding to the training phase.
        out = K.in_train_phase(normed, normalize_inference, training=training)
        return out * broadcast_gamma + broadcast_beta
Example #34
    def call(self, inputs, training=None):
        _, w, h, c = K.int_shape(inputs)
        bs = K.shape(inputs)[0]

        m, f = utils.center(inputs, self.moving_mean, self.instance_norm)
        get_inv_sqrt = utils.get_decomposition(self.decomposition, bs,
                                               self.group, self.instance_norm,
                                               self.iter_num, self.epsilon,
                                               self.device)

        def train():
            ff_aprs = utils.get_group_cov(f, self.group, self.m_per_group,
                                          self.instance_norm, bs, w, h, c)

            if self.instance_norm:
                ff_aprs = tf.transpose(ff_aprs, (1, 0, 2, 3))
                ff_aprs = (1 - self.epsilon) * ff_aprs + tf.expand_dims(
                    tf.expand_dims(tf.eye(self.m_per_group) * self.epsilon, 0),
                    0)
            else:
                ff_aprs = (1 - self.epsilon) * ff_aprs + tf.expand_dims(
                    tf.eye(self.m_per_group) * self.epsilon, 0)

            whitten_matrix = get_inv_sqrt(ff_aprs, self.m_per_group)[1]

            self.add_update([
                K.moving_average_update(self.moving_mean, m, self.momentum),
                K.moving_average_update(
                    self.moving_matrix,
                    whitten_matrix if '_wm' in self.decomposition else ff_aprs,
                    self.momentum)
            ], inputs)

            if self.renorm:
                l, l_inv = get_inv_sqrt(ff_aprs, self.m_per_group)
                ff_mov = (1 - self.epsilon) * self.moving_matrix + tf.eye(
                    self.m_per_group) * self.epsilon
                _, l_mov_inverse = get_inv_sqrt(ff_mov, self.m_per_group)
                l_ndiff = K.stop_gradient(l)
                return tf.matmul(tf.matmul(l_mov_inverse, l_ndiff), l_inv)

            return whitten_matrix

        def test():
            moving_matrix = (1 - self.epsilon) * self.moving_matrix + tf.eye(
                self.m_per_group) * self.epsilon
            if '_wm' in self.decomposition:
                return moving_matrix
            else:
                return get_inv_sqrt(moving_matrix, self.m_per_group)[1]

        if self.instance_norm == 1:
            inv_sqrt = train()
            f = tf.reshape(f, [-1, self.group, self.m_per_group, w * h])
            f_hat = tf.matmul(inv_sqrt, f)
            decorelated = K.reshape(f_hat, [bs, c, w, h])
            decorelated = tf.transpose(decorelated, [0, 2, 3, 1])
        else:
            inv_sqrt = K.in_train_phase(train, test)
            f = tf.reshape(f, [self.group, self.m_per_group, -1])
            f_hat = tf.matmul(inv_sqrt, f)
            decorelated = K.reshape(f_hat, [c, bs, w, h])
            decorelated = tf.transpose(decorelated, [1, 2, 3, 0])

        return decorelated