Example No. 1
    def call(self, inputs, training=None):
        if not self.trainable:
            training = False
        else:
            # The learning phase flag is a bool tensor (0 = test, 1 = train)
            training = K.learning_phase()

        if training is not False:
            K.update_add(self.iterations, 1)
            # compute current mean&var
            mini_mean, mini_variance = tf.nn.moments(inputs, axes=[0,1,2])
            # affine the inputs
            x = (inputs - self.steps_mean) / K.sqrt(self.steps_variance + self.epsilon)
            x = self.gamma * x + self.beta
            # update the moving params
            K.moving_average_update(self.moving_mean, mini_mean, self.momentum)
            K.moving_average_update(self.moving_variance, mini_variance, self.momentum)
            # update the short-term params under specific condition
            cond = K.equal(self.iterations % self.steps_per_update, 0)
            K.switch(cond, lambda: self.steps_mean*0, K.update_add(self.steps_mean, mini_mean))
            K.switch(cond, lambda: self.steps_variance*0, K.update_add(self.steps_variance, mini_variance))
        else:
            # affine
            scale = self.gamma / K.sqrt(self.moving_variance + self.epsilon)
            x = inputs * scale + (self.beta - self.moving_mean * scale)
        return x
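For context, a note on what K.moving_average_update(variable, value, momentum) does in all of the examples below: it registers an op that blends the current batch statistic into the stored running statistic. A rough NumPy sketch of the update rule, ignoring the optional zero-debias correction some backend versions apply, is:

import numpy as np

def moving_average_update(variable, value, momentum):
    # variable <- momentum * variable + (1 - momentum) * value
    variable *= momentum
    variable += (1.0 - momentum) * value
    return variable

running_mean = np.zeros(3)
batch_mean = np.array([0.5, 1.0, -0.2])
moving_average_update(running_mean, batch_mean, momentum=0.99)
# running_mean is now [0.005, 0.01, -0.002]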
Example No. 2
 def call(self, x):
     mean, var = tf.nn.moments(x, [0])
     self.add_update([
         K.moving_average_update(self.mu, mean, self._mu_l),
         K.moving_average_update(self.sigma, tf.sqrt(var), self._sigma_l)
     ], x)
     return (x - self.mu) / (self.sigma + self._eps)
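Example No. 2 only shows the call() method; below is a minimal, self-contained sketch of how such a layer could be completed. The class name SimpleMovingNorm and the build() details are assumptions for illustration, assuming TF 1.x-style standalone Keras where add_update still accepts an inputs argument:

import tensorflow as tf
from keras import backend as K
from keras.layers import Layer

class SimpleMovingNorm(Layer):
    """Normalize with running mean/std tracked as exponential moving averages."""

    def __init__(self, momentum=0.99, epsilon=1e-3, **kwargs):
        super(SimpleMovingNorm, self).__init__(**kwargs)
        self._mu_l = momentum       # decay for the running mean
        self._sigma_l = momentum    # decay for the running std
        self._eps = epsilon

    def build(self, input_shape):
        dim = input_shape[-1]
        self.mu = self.add_weight(name='mu', shape=(dim,),
                                  initializer='zeros', trainable=False)
        self.sigma = self.add_weight(name='sigma', shape=(dim,),
                                     initializer='ones', trainable=False)
        super(SimpleMovingNorm, self).build(input_shape)

    def call(self, x):
        mean, var = tf.nn.moments(x, [0])
        self.add_update([
            K.moving_average_update(self.mu, mean, self._mu_l),
            K.moving_average_update(self.sigma, tf.sqrt(var), self._sigma_l)
        ], x)
        return (x - self.mu) / (self.sigma + self._eps)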
Example No. 3
    def call(self, x_cat, mask=None):
        # For some reason, we have to concatenate vectors to feed them using "merge" in keras
        x_cat = self.epsilon + (
            1 - 2. * self.epsilon
        ) * x_cat  #K.clip(x_cat, self.epsilon, 1 - self.epsilon)  # Avoid NANs
        z = x_cat[:, :self.size]
        x = x_cat[:, self.size:]
        batch_size = K.cast(
            K.shape(x)[0],
            x.dtype)  # This is a node tensor, so we can't treat as integer
        div_n = Lambda(
            lambda v: v / batch_size
        )  # Dividing by batch size is an operation on unknown tensor

        # batch statistics
        px = K.expand_dims(K.mean(x, axis=0), 0)  # p(xi = 1)
        py = K.expand_dims(K.mean(z, axis=0), 1)  # mean of z_j
        V = div_n(K.dot(K.transpose(z), x))  # j i

        self.add_update([
            K.moving_average_update(self.Vr, V, self.momentum),
            K.moving_average_update(self.pxr, px, self.momentum),
            K.moving_average_update(self.pyr, py, self.momentum)
        ], x_cat)
        V = K.in_train_phase(V, self.Vr)
        px = K.in_train_phase(px, self.pxr)
        py = K.in_train_phase(py, self.pyr)
        eta1 = V / px
        eta0 = (py - V) / (1 - px)
        W = K.log(eta1) - K.log(1 - eta1) + K.log(1 - eta0) - K.log(eta0)
        out = K.log(px) - K.log(1. - px) + K.dot(z, W) + K.sum(
            K.log(1. - eta1) - K.log(1. - eta0), 0, keepdims=True)
        return K.sigmoid(out)
Example No. 4
        def training_phase():
            mean_batch = K.mean(mean_instance, axis=0, keepdims=True)
            variance_batch = K.mean(temp, axis=0,
                                    keepdims=True) - K.square(mean_batch)

            mean_batch_reshaped = K.flatten(mean_batch)
            variance_batch_reshaped = K.flatten(variance_batch)

            if K.backend() != 'cntk':
                sample_size = K.prod(
                    [K.shape(inputs)[axis] for axis in reduction_axes])
                sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

                # sample variance - unbiased estimator of population variance
                variance_batch_reshaped *= sample_size / (sample_size -
                                                          (1.0 + self.epsilon))

            self.add_update([
                K.moving_average_update(self.moving_mean, mean_batch_reshaped,
                                        self.momentum),
                K.moving_average_update(self.moving_variance,
                                        variance_batch_reshaped, self.momentum)
            ], inputs)

            return normalize_func(mean_batch, variance_batch)
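The in-place rescaling above turns the biased batch variance into (approximately) the unbiased sample estimate; the extra self.epsilon in the denominator only guards against dividing by zero when the sample size is 1. A hedged numeric illustration of the factor:

n = 32.0                                    # elements reduced over
biased_var = 4.0                            # mean of squared deviations
unbiased_var = biased_var * n / (n - 1.0)   # Bessel's correction, ~4.129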
Example No. 5
    def call(self, inputs, training=None):
        inputs, spk_id = inputs
        spk_id = K.cast(K.flatten(spk_id)[0], 'int32')

        def normalize_inference():
            return K.normalize_batch_in_training(inputs,
                                                 self.gamma[spk_id],
                                                 self.beta[spk_id], [0, 1],
                                                 epsilon=self.epsilon)[0]

        normed_training, mean, variance = K.normalize_batch_in_training(
            inputs,
            self.gamma[spk_id],
            self.beta[spk_id], [0, 1],
            epsilon=self.epsilon)

        sample_size = K.shape(inputs)[1]
        sample_size = K.cast(sample_size, dtype=K.dtype(inputs))
        variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        self.add_update([
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            K.moving_average_update(self.moving_variance, variance,
                                    self.momentum)
        ], inputs)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)
Example No. 6
    def call(self, x, mask=None):
        if self.mode == 0 or self.mode == 2:
            assert self.built, 'Layer must be built before being called'
            input_shape = self.input_spec[0].shape

            reduction_axes = list(range(len(input_shape)))
            del reduction_axes[self.axis]
            broadcast_shape = [1] * len(input_shape)
            broadcast_shape[self.axis] = input_shape[self.axis]

            if self.mode == 2:
                x_normed, mean, std = K.normalize_batch_in_training(
                    x, self.gamma, self.beta, reduction_axes,
                    epsilon=self.epsilon)
            else:
                # mode 0
                if self.called_with not in {None, x} and False:
                    raise Exception('You are attempting to share a '
                                    'same `BatchNormalization` layer across '
                                    'different data flows. '
                                    'This is not possible. '
                                    'You should use `mode=2` in '
                                    '`BatchNormalization`, which has '
                                    'a similar behavior but is shareable '
                                    '(see docs for a description of '
                                    'the behavior).')
                self.called_with = x
                x_normed, mean, std = K.normalize_batch_in_training(
                    x, self.gamma, self.beta, reduction_axes,
                    epsilon=self.epsilon)

                self.updates = [K.moving_average_update(self.running_mean, mean, self.momentum),
                                K.moving_average_update(self.running_std, std, self.momentum)]

                if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                    x_normed_running = K.batch_normalization(
                        x, self.running_mean, self.running_std,
                        self.beta, self.gamma,
                        epsilon=self.epsilon)
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape)
                    broadcast_running_std = K.reshape(self.running_std, broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        x, broadcast_running_mean, broadcast_running_std,
                        broadcast_beta, broadcast_gamma,
                        epsilon=self.epsilon)

                # pick the normalized form of x corresponding to the training phase
                x_normed = K.in_train_phase(x_normed, x_normed_running)

        elif self.mode == 1:
            # sample-wise normalization
            m = K.mean(x, axis=-1, keepdims=True)
            std = K.sqrt(K.var(x, axis=-1, keepdims=True) + self.epsilon)
            x_normed = (x - m) / (std + self.epsilon)
            x_normed = self.gamma * x_normed + self.beta
        return x_normed
Example No. 7
    def call(self, x, mask=None):
        if self.mode == 0 or self.mode == 2:
            assert self.built, 'Layer must be built before being called'
            input_shape = K.int_shape(x)

            reduction_axes = list(range(len(input_shape)))
            del reduction_axes[self.axis]
            broadcast_shape = [1] * len(input_shape)
            broadcast_shape[self.axis] = input_shape[self.axis]

            x_normed, mean, std = K.normalize_batch_in_training(
                x, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)

            if self.mode == 0:
                self.add_update([
                    K.moving_average_update(self.running_mean, mean,
                                            self.momentum),
                    K.moving_average_update(self.running_std, std,
                                            self.momentum)
                ], x)

                if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                    x_normed_running = K.batch_normalization(
                        x,
                        self.running_mean,
                        self.running_std,
                        self.beta,
                        self.gamma,
                        epsilon=self.epsilon)
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean,
                                                       broadcast_shape)
                    broadcast_running_std = K.reshape(self.running_std,
                                                      broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        x,
                        broadcast_running_mean,
                        broadcast_running_std,
                        broadcast_beta,
                        broadcast_gamma,
                        epsilon=self.epsilon)

                # pick the normalized form of x corresponding to the training phase
                x_normed = K.in_train_phase(x_normed, x_normed_running)

        elif self.mode == 1:
            # sample-wise normalization
            m = K.mean(x, axis=-1, keepdims=True)
            std = K.sqrt(K.var(x, axis=-1, keepdims=True) + self.epsilon)
            x_normed = (x - m) / (std + self.epsilon)
            x_normed = self.gamma * x_normed + self.beta
        else:
            return None
        return x_normed
Example No. 8
    def call(self, inputs, training=None, **kwargs):

        G = self.groups

        # transpose: [bs,h,w,c] -> [bs,c,h,w]
        if self.axis in {-1, 3}:
            inputs = K.permute_dimensions(inputs, (0, 3, 1, 2))

        input_shape = K.int_shape(inputs)
        N, C, H, W = input_shape
        inputs = K.reshape(inputs, (-1, G, C // G, H, W))
        # inputs.assign_sub()

        # compute group-channel mean & variance
        gn_mean = K.mean(inputs, axis=[2, 3, 4], keepdims=True)
        gn_variance = K.var(inputs, axis=[2, 3, 4], keepdims=True)

        # compute group-normalization in different state
        def gn_inference():
            # when in test phase, just return moving_mean & moving_var
            mean, variance = self.moving_mean, self.moving_variance
            outputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))
            outputs = K.reshape(outputs, [-1, C, H, W]) * self.gamma + self.beta
            # transpose: [bs,c,h,w] -> [bs,h,w,c]
            if self.axis in {-1, 3}:
                outputs = K.permute_dimensions(outputs, (0, 2, 3, 1))

            return outputs

        if training in {0, False}:
            return gn_inference()

        outputs = (inputs - gn_mean) / (K.sqrt(gn_variance + self.epsilon))
        outputs = K.reshape(outputs, [-1, C, H, W]) * self.gamma + self.beta

        # transpose: [bs,c,h,w] -> [bs,h,w,c]
        if self.axis in {-1, 3}:
            outputs = K.permute_dimensions(outputs, (0, 2, 3, 1))

        self.add_update([K.moving_average_update(self.moving_mean,
                                                 gn_mean,
                                                 self.momentum),
                         K.moving_average_update(self.moving_variance,
                                                 gn_variance,
                                                 self.momentum)],
                        inputs)

        # print("moving_mean shape : ",K.int_shape(self.moving_mean))
        # print("moving_mean: ",K.eval(self.moving_mean))
        # print("moving_variance shape: ",K.int_shape(self.moving_variance))
        # print("moving_variance: ",K.eval(self.moving_variance))

        return K.in_train_phase(outputs,
                                gn_inference,
                                training=training)
Example No. 9
    def call(self, x, mask=None):
        output = K.conv2d(x, self.W, strides=self.subsample,
                          border_mode=self.border_mode,
                          dim_ordering=self.dim_ordering,
                          filter_shape=self.W_shape)

        # added for batch normalization
        input_shape = K.int_shape(output)
        axis = 1

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[axis] = input_shape[axis]

        output_normed, mean, std = K.normalize_batch_in_training(
            output, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        self.add_update([K.moving_average_update(self.running_mean, mean, self.momentum),
                         K.moving_average_update(self.running_std, std, self.momentum)], output)

        if sorted(reduction_axes) == list(range(K.ndim(output)))[:-1]:
            output_normed_running = K.batch_normalization(
                output, self.running_mean, self.running_std,
                self.beta, self.gamma,
                epsilon=self.epsilon)
        else:
            # need broadcasting
            broadcast_running_mean = K.reshape(self.running_mean, broadcast_shape)
            broadcast_running_std = K.reshape(self.running_std, broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            output_normed_running = K.batch_normalization(
                output, broadcast_running_mean, broadcast_running_std,
                broadcast_beta, broadcast_gamma,
                epsilon=self.epsilon)

        # pick the normalized form of output corresponding to the training phase
        output_normed = K.in_train_phase(output_normed, output_normed_running)


        if self.bias:
            if self.dim_ordering == 'th':
                output_normed += K.reshape(self.b, (1, self.nb_filter, 1, 1))
            elif self.dim_ordering == 'tf':
                output_normed += K.reshape(self.b, (1, 1, 1, self.nb_filter))
            else:
                raise ValueError('Invalid dim_ordering:', self.dim_ordering)
        output = self.activation(output_normed)
        return output
Example No. 10
        def normed_training():
            mean_bn = K.mean(inputs, axis=reduction_axes_bn,keepdims=True)
            variance_bn = K.var(inputs, axis=reduction_axes_bn,keepdims=True)
            mean = [mean_in, mean_ln, mean_bn]
            variance = [variance_in, variance_ln, variance_bn]

            # If the learning is either dynamic, or set to training:
            self.add_update([K.moving_average_update(self.moving_mean,
                                                     K.reshape(mean_bn,(input_shape[self.axis],)),
                                                     self.momentum),
                             K.moving_average_update(self.moving_variance,
                                                     K.reshape(variance_bn,(input_shape[self.axis],)),
                                                     self.momentum)],
                            inputs)
            return norm(mean, variance)
Example No. 11
    def call(self, x_cat, mask=None):
        # For some reason, we have to concatenate vectors to feed them using "merge" in keras
        z = x_cat[:, :self.size]
        x = x_cat[:, self.size:]
        batch_size = K.cast(
            K.shape(x)[0],
            x.dtype)  # This is a node tensor, so we can't treat as integer
        div_n = Lambda(
            lambda v: v / batch_size
        )  # Dividing by batch size is an operation on unknown tensor

        # batch statistics
        self.mi = K.expand_dims(K.mean(x, axis=0), 0)  # mean of x_i
        self.mj = K.expand_dims(K.mean(z, axis=0), 1)  # mean of z_j
        self.vj = K.expand_dims(K.var(z, axis=0) + self.epsilon,
                                1)  # sigma_j^2
        self.vi = K.expand_dims(K.var(x, axis=0) + self.epsilon,
                                0)  # sigma_i^2

        #CHANGE BACK
        #self.V = div_n(K.dot(K.transpose(z), x))
        self.V = div_n(
            K.dot(K.transpose(z - K.transpose(self.mj)), x - self.mi))  # j i

        self.add_update([
            K.moving_average_update(self.Vr, self.V, self.momentum),
            K.moving_average_update(self.mir, self.mi, self.momentum),
            K.moving_average_update(self.mjr, self.mj, self.momentum),
            K.moving_average_update(self.vjr, self.vj, self.momentum),
            K.moving_average_update(self.vir, self.vi, self.momentum)
        ], x_cat)
        V = K.in_train_phase(self.V, self.Vr)
        mi = K.in_train_phase(self.mi, self.mir)
        mj = K.in_train_phase(self.mj, self.mjr)
        vj = K.in_train_phase(self.vj, self.vjr)
        vi = K.in_train_phase(self.vi, self.vir)

        #CHANGE BACK
        #rho = (V - mi * mj) / K.sqrt(vi * vj)
        rho = V / K.sqrt(vi * vj)
        Q = rho / (1 - K.square(rho))
        self.R = K.sum(rho * Q, axis=0, keepdims=True)
        Q = Q / (1 + self.R)
        if self.return_r:
            return self.R
        else:
            return mi + K.sqrt(vi) * K.dot(
                K.transpose((K.transpose(z) - mj) / K.sqrt(vj)), Q)
Example No. 12
 def inject(self):
     """添加更新算子到model.metrics_updates。
     """
     self.initialize()
     for w1, w2 in zip(self.ema_weights, self.model.weights):
         op = K.moving_average_update(w1, w2, self.momentum)
         self.model.metrics_updates.append(op)
Example No. 13
 def update_erm():
     normed_training, mean, variance = K.normalize_batch_in_training(
         x=inputs, beta=None, gamma=None, reduction_axes=reduction_axes)
     self.add_update(
         [K.moving_average_update(self.values, mean, self.momentum)],
         inputs=inputs)
     return self.values
Example No. 14
    def call(self, inputs, training=None):
        if training is None:
            training = bk.learning_phase()
        training = bk.get_value(training)

        if training:
            bk.moving_average_update(self.moving_min, bk.min(inputs, axis=0),
                                     self.momentum)
            bk.moving_average_update(self.moving_max, bk.max(inputs, axis=0),
                                     self.momentum)

        scale = (self.max_val - self.min_val) / (
            self.moving_max - self.moving_min + self.epsilon)
        output = bk.clip((inputs - self.moving_min) * scale + self.min_val,
                         self.min_val, self.max_val)
        return output
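To make the affine rescaling above concrete, here is a hedged plain-Python sketch of the same map applied at inference time (all numbers are made up):

min_val, max_val = 0.0, 1.0         # target output range
moving_min, moving_max = -2.0, 6.0  # tracked input extrema
epsilon = 1e-7

x = 2.0
scale = (max_val - min_val) / (moving_max - moving_min + epsilon)   # ~0.125
y = min(max((x - moving_min) * scale + min_val, min_val), max_val)  # 0.5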
Example No. 15
    def call(self, inputs, training=None):
        x = inputs
        assert not isinstance(x, list)

        # Compute the minibatch statistics
        mean, var = self._moments(x)
        sigma = K.sqrt(var + self.epsilon)

        # Outside the training phase, rmax and dmax are set large so that the
        # unclipped corrections reduce the normalization to using the moving averages
        rmax = K.in_train_phase(self.rmax, K.constant(1e5), training)
        dmax = K.in_train_phase(self.dmax, K.constant(1e5), training)

        # Compute the corrections based on rmax, dmax
        r = K.stop_gradient(
            self._clip(sigma / self.moving_sigma, 1. / rmax, rmax))
        d = K.stop_gradient(
            self._clip((mean - self.moving_mean) / self.moving_sigma, -dmax,
                       dmax))

        # Actually do the normalization and the rescaling
        xnorm = ((x - mean) / sigma) * r + d
        y = self.gamma * xnorm + self.beta

        # Add the moving average updates
        self.add_update([
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            K.moving_average_update(self.moving_sigma, sigma, self.momentum)
        ], x)

        # Add the r, d updates
        rmax_prog = K.minimum(1., self.steps / self.rmax_dur)
        dmax_prog = K.minimum(1., self.steps / self.dmax_dur)
        self.add_update([
            K.update_add(self.steps, 1),
            K.update(self.rmax,
                     self.rmax_0 + rmax_prog * (self.rmax_inf - self.rmax_0)),
            K.update(self.dmax,
                     self.dmax_0 + dmax_prog * (self.dmax_inf - self.dmax_0))
        ])

        # Fix the output's uses learning phase
        y._uses_learning_phase = rmax._uses_learning_phase

        return y
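Example No. 15 is Batch Renormalization: the r and d corrections make the training-time output agree with what plain batch norm would produce from the moving statistics, while gradients still flow through the batch statistics. A small numeric sketch of the correction defined above (made-up values, assuming r and d are not clipped):

import numpy as np

mu_b, sigma_b = 0.4, 1.5        # batch statistics
mu_mov, sigma_mov = 0.0, 1.0    # moving statistics
rmax, dmax = 2.0, 1.0           # current clipping limits

r = np.clip(sigma_b / sigma_mov, 1.0 / rmax, rmax)      # -> 1.5
d = np.clip((mu_b - mu_mov) / sigma_mov, -dmax, dmax)   # -> 0.4
x = 2.0
xnorm = ((x - mu_b) / sigma_b) * r + d                  # -> 2.0
# identical to normalizing x directly with the moving statistics:
assert abs(xnorm - (x - mu_mov) / sigma_mov) < 1e-12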
Example No. 16
 def inject(self):
     """添加更新算子到model.metrics_updates。
     """
     self.initialize()
     for w1, w2 in zip(self.ema_weights, self.model.weights):
         op = K.moving_average_update(w1, w2, self.momentum)
         #self.model.metrics_updates.append(op)  # works in keras 2.2.4
         if not hasattr(self.model, '_other_metrics'):
             self.model._other_metrics = []
         self.model._other_metrics.append(op)
Example No. 17
    def call(self, inputs, training=None):
        if len(inputs) == 3:
            params, trainable_params, x = inputs
            params = self.merge_params(params, trainable_params)
        elif len(inputs) == 2:
            params, x = inputs
        else:
            raise ValueError("Wrong number of inputs")

        offset = 0
        for layer in self.layers:
            layer_params = params[:, offset:offset + layer["num_params"]]
            offset += layer["num_params"]

            if layer["type"] in ["standard-batchnorm", "batch-renorm"]:
                x = K.stack(x, 0)
                self.mean, self.variance = tf.nn.moments(x, [0, 1, 2])

                if training:
                    sample_size = K.prod(
                        [K.shape(x)[axis] for axis in [0, 1, 2]])
                    sample_size = K.cast(sample_size, dtype='float32')
                    unbiased_variance = self.variance * sample_size / (
                        sample_size - (1.0 + layer["epsilon"]))

                    self.add_update([
                        K.moving_average_update(
                            self.moving_means[layer["name"]], self.mean,
                            layer["momentum"]),
                        K.moving_average_update(
                            self.moving_vars[layer["name"]], unbiased_variance,
                            layer["momentum"]),
                    ], inputs)

            x = [
                self.evaluate_layer(layer, layer_params[i], x[i], training)
                for i in range(self.batch_size)
            ]

        output = K.stack(x, 0)
        output._uses_learning_phase = True
        return output
Example No. 18
    def call(self, x_cat, mask=None):
        # For some reason, we have to concatenate vectors to feed them using "merge" in keras
        z = x_cat[:, :self.size]
        x = K.clip(x_cat[:, self.size:], self.epsilon, 1. - self.epsilon)
        batch_size = K.cast(
            K.shape(x)[0],
            x.dtype)  # This is a node tensor, so we can't treat as integer
        div_n = Lambda(
            lambda v: v / batch_size
        )  # Dividing by batch size is an operation on unknown tensor

        # batch statistics
        pi = K.expand_dims(K.mean(x, axis=0), 0)  # p(xi = 1)
        mj = K.expand_dims(K.mean(z, axis=0), 1)  # mean of z_j
        vj = K.expand_dims(K.mean(K.square(z), axis=0),
                           1)  # expectation of z^2
        V = div_n(K.dot(K.transpose(z), x))  # j i
        S = div_n(K.dot(K.transpose(K.square(z)), x))  # j i

        self.add_update([
            K.moving_average_update(self.Vr, V, self.momentum),
            K.moving_average_update(self.Sr, S, self.momentum),
            K.moving_average_update(self.pir, pi, self.momentum),
            K.moving_average_update(self.mjr, mj, self.momentum),
            K.moving_average_update(self.vjr, vj, self.momentum)
        ], x_cat)
        V = K.in_train_phase(V, self.Vr)
        S = K.in_train_phase(S, self.Sr)
        pi = K.in_train_phase(pi, self.pir)
        mj = K.in_train_phase(mj, self.mjr)
        vj = K.in_train_phase(vj, self.vjr)

        mu0, mu1, sig0, sig1 = self.get_mean_sig(mj, vj, pi, V, S)

        out = (K.log(pi) - K.log(1. - pi) -
               0.5 * K.sum(K.log(sig1) - K.log(sig0), 0) +
               0.5 * K.sum(K.square(mu0) / sig0 - K.square(mu1) / sig1, 0) +
               K.dot(z, mu1 / sig1 - mu0 / sig0) +
               0.5 * K.dot(K.square(z), 1. / sig0 - 1. / sig1))
        return K.sigmoid(out)
Example No. 19
    def _update_embedding(self, x, y, seg_indices, seg_embeddings):
        dtype = self.embedding.dtype
        delta_embeddings = (1 - self.target_momentum) * (y - seg_embeddings)
        tmp_embedding, tmp_cnt = self._sum_seg_embeddings(
            seg_indices, delta_embeddings)

        bk.update_add(
            self.embedding,
            tmp_embedding / (tmp_cnt + bk.cast(0 == tmp_cnt, dtype=dtype)))
        bk.update_add(self.update_cnt, tmp_cnt)

        if self.mask_zero:
            min_val = bk.min(x + bk.constant(self.val_inf, dtype=dtype) *
                             bk.cast(0 == x, dtype),
                             axis=0)
            max_val = bk.max(x + bk.constant(-self.val_inf, dtype=dtype) *
                             bk.cast(0 == x, dtype),
                             axis=0)
        else:
            min_val, max_val = bk.min(x, axis=0), bk.max(x, axis=0)
        bk.moving_average_update(self.moving_min, min_val, self.val_momentum)
        bk.moving_average_update(self.moving_max, max_val, self.val_momentum)
Example No. 20
    def call(self, inputs, training=None):
        x = inputs
        assert not isinstance(x, list)

        # Do the normalization and the rescaling
        xnorm = K.batch_normalization(x,
                                      self.moving_mean,
                                      self.moving_variance,
                                      self.beta,
                                      self.gamma,
                                      epsilon=self.epsilon)

        # Compute and update the minibatch statistics
        if self.update_stats:
            mean, var = self._moments(x, axes=range(len(K.int_shape(x)) - 1))
            self.add_update([
                K.moving_average_update(self.moving_mean, mean, self.momentum),
                K.moving_average_update(self.moving_variance, var,
                                        self.momentum)
            ], x)

        return xnorm
Example No. 21
        def batch_norm(inputs, gamma, beta, dims, ind):
            """ Normalize batch and update moving averages for mean and std
            Input:
              inputs: (batchsize, n_points, k, n_features * 2) - edge_features
              gamma: weight - gamma for batch normalization
              beta: weight - beta for batch normalization
              dims: list - dimensions along which to normalize
              ind: int - indicating which weights to use
            Returns:
             During training:
              normed: (batchsize, n_points, k, n_features * 2) - normalized
                            batch of data using actual batch for normalization
             Else:
              normed_moving: same, but using the updated average values
            """

            # Calculate normalized data, mean and std for batch
            normed, batch_mean, batch_var = K.normalize_batch_in_training(
                                                x=inputs,
                                                gamma=gamma,
                                                beta=beta,
                                                reduction_axes=dims)

            # Update the moving averages
            self.add_update([
                K.moving_average_update(self.moving_mean[ind], batch_mean, 0.9),
                K.moving_average_update(self.moving_var[ind], batch_var, 0.9)])

            # Calculate normalization using the averages
            normed_moving = K.batch_normalization(
                                                x=inputs,
                                                mean=self.moving_mean[ind],
                                                var=self.moving_var[ind],
                                                beta=beta,
                                                gamma=gamma)

            # If training return normed, else normed_moving
            return K.in_train_phase(normed, normed_moving)
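The final line defers the choice between the two normalized tensors to the Keras learning phase. A hedged minimal illustration of how K.in_train_phase resolves when the phase is given as a plain flag:

from keras import backend as K

a = K.constant(1.0)   # stands in for `normed`
b = K.constant(2.0)   # stands in for `normed_moving`

picked_train = K.in_train_phase(a, b, training=1)  # -> a
picked_test = K.in_train_phase(a, b, training=0)   # -> b
# With training=None (as above) the choice is deferred to the symbolic
# learning phase set by fit()/predict().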
Example No. 22
    def call(self, x_cat, mask=None):
        # For some reason, we have to concatenate vectors to feed them using "merge" in keras
        z = x_cat[:, :self.size]
        x = x_cat[:, self.size:]
        batch_size = K.cast(
            K.shape(x)[0],
            x.dtype)  # This is a node tensor, so we can't treat as integer
        div_n = Lambda(
            lambda v: v / batch_size
        )  # Dividing by batch size is an operation on unknown tensor

        # batch statistics
        pi = K.expand_dims(
            K.clip(K.mean(x, axis=0), self.epsilon, 1. - self.epsilon),
            0)  # p(xi = 1)
        mj = K.expand_dims(K.mean(z, axis=0), 1)  # mean of z_j
        vj = K.expand_dims(K.var(z, axis=0) + self.epsilon, 1)  # sigma_j^2
        V = div_n(K.dot(K.transpose(z), x))  # j i

        self.add_update([
            K.moving_average_update(self.Vr, V, self.momentum),
            K.moving_average_update(self.pir, pi, self.momentum),
            K.moving_average_update(self.mjr, mj, self.momentum),
            K.moving_average_update(self.vjr, vj, self.momentum)
        ], x_cat)
        V = K.in_train_phase(V, self.Vr)
        pi = K.in_train_phase(pi, self.pir)
        mj = K.in_train_phase(mj, self.mjr)
        vj = K.in_train_phase(vj, self.vjr)

        mu_diff = (V - mj * pi) / (
            pi * (1 - pi))  # difference between mu_xi=1^j - mu_xi=0^j
        mu_mean = 0.5 * (V / pi + (mj - V) / (1 - pi))  # average of means
        out = K.log(pi) - K.log(1. - pi) + K.dot(z, mu_diff / vj) - K.sum(
            mu_diff * mu_mean / vj, 0, keepdims=True)
        return K.sigmoid(out)
Example No. 23
        def train():
            ff_apr = ktf.matmul(f, f, transpose_b=True) / (
                ktf.cast(bs * w * h, ktf.float32) - 1.)
            if self.decomposition in ['pca-cor', 'zca-cor']:
                dinv = ktf.diag(ktf.sqrt(ktf.diag_part(ff_apr)))
                ff_apr = ktf.matmul(ktf.matmul(dinv, ff_apr),
                                    ktf.matrix_inverse(dinv),
                                    transpose_b=True)
            self.add_update([
                K.moving_average_update(self.moving_mean, m, self.momentum),
                K.moving_average_update(self.moving_cov, ff_apr, self.momentum)
            ], inputs)
            ff_apr_shrinked = (
                1 - self.epsilon) * ff_apr + ktf.eye(c) * self.epsilon

            if self.renorm:
                l, l_inv = get_inv_sqrt(ff_apr_shrinked)
                ff_mov = (1 - self.epsilon
                          ) * self.moving_cov + ktf.eye(c) * self.epsilon
                _, l_mov_inverse = get_inv_sqrt(ff_mov)
                l_ndiff = K.stop_gradient(l)
                return ktf.matmul(ktf.matmul(l_mov_inverse, l_ndiff), l_inv)

            return get_inv_sqrt(ff_apr_shrinked)[1]
Example No. 24
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]

        # inference
        def normalize_inference():
            return inputs - self.moving_mean

        if training in {0, False}:
            return normalize_inference()

        mean = K.mean(inputs, axis=reduction_axes)
        normed_training = inputs - mean

        self.add_update(
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            inputs)

        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)
Example No. 25
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        ndim = len(input_shape)
        reduction_axes = list(range(ndim))
        del reduction_axes[self.axis]
        input_dim = input_shape[self.axis] // 4
        mu = K.mean(inputs, axis=reduction_axes)
        broadcast_mu_shape = [1] * len(input_shape)
        broadcast_mu_shape[self.axis] = input_shape[self.axis]
        broadcast_mu = K.reshape(mu, broadcast_mu_shape)
        if self.center:
            input_centred = inputs - broadcast_mu
        else:
            input_centred = inputs
        centred_squared = input_centred ** 2
        if (self.axis == 1 and ndim != 3) or ndim == 2:
            centred_squared_r = centred_squared[:, :input_dim]
            centred_squared_i = centred_squared[:, input_dim:input_dim*2]
            centred_squared_j = centred_squared[:, input_dim*2:input_dim*3]
            centred_squared_k = centred_squared[:, input_dim*3:]
            centred_r = input_centred[:, :input_dim]
            centred_i = input_centred[:, input_dim:input_dim*2]
            centred_j = input_centred[:, input_dim*2:input_dim*3]
            centred_k = input_centred[:, input_dim*3:]
        elif ndim == 3:
            centred_squared_r = centred_squared[:, :, :input_dim]
            centred_squared_i = centred_squared[:, :, input_dim:input_dim*2]
            centred_squared_j = centred_squared[:, :, input_dim*2:input_dim*3]
            centred_squared_k = centred_squared[:, :, input_dim*3:]
            centred_r = input_centred[:, :, :input_dim]
            centred_i = input_centred[:, :, input_dim:input_dim*2]
            centred_j = input_centred[:, :, input_dim*2:input_dim*3]
            centred_k = input_centred[:, :, input_dim*3:]
        elif self.axis == -1 and ndim == 4:
            centred_squared_r = centred_squared[:, :, :, :input_dim]
            centred_squared_i = centred_squared[:, :, :, input_dim:input_dim*2]
            centred_squared_j = centred_squared[:, :, :, input_dim*2:input_dim*3]
            centred_squared_k = centred_squared[:, :, :, input_dim*3:]
            centred_r = input_centred[:, :, :, :input_dim]
            centred_i = input_centred[:, :, :, input_dim:input_dim*2]
            centred_j = input_centred[:, :, :, input_dim*2:input_dim*3]
            centred_k = input_centred[:, :, :, input_dim*3:]
        elif self.axis == -1 and ndim == 5:
            centred_squared_r = centred_squared[:, :, :, :, :input_dim]
            centred_squared_i = centred_squared[:, :, :, :, input_dim:input_dim*2]
            centred_squared_j = centred_squared[:, :, :, :, input_dim*2:input_dim*3]
            centred_squared_k = centred_squared[:, :, :, :, input_dim*3:]
            centred_r = input_centred[:, :, :, :, :input_dim]
            centred_i = input_centred[:, :, :, :, input_dim:input_dim*2]
            centred_j = input_centred[:, :, :, :, input_dim*2:input_dim*3]
            centred_k = input_centred[:, :, :, :, input_dim*3:]
        else:
            raise ValueError(
                'Incorrect Batchnorm combination of axis and dimensions. axis should be either 1 or -1. '
                'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.'
            )
        if self.scale:
            Vrr = K.mean(
                centred_squared_r,
                axis=reduction_axes
            ) + self.epsilon
            Vii = K.mean(
                centred_squared_i,
                axis=reduction_axes
            ) + self.epsilon
            Vjj = K.mean(
                centred_squared_j,
                axis=reduction_axes
            ) + self.epsilon
            Vkk = K.mean(
                centred_squared_k,
                axis=reduction_axes
            ) + self.epsilon
            Vri = K.mean(
                centred_r * centred_i,
                axis=reduction_axes,
            )
            Vrj = K.mean(
                centred_r * centred_j,
                axis=reduction_axes,
            )
            Vrk = K.mean(
                centred_r * centred_k,
                axis=reduction_axes,
            )
            Vij = K.mean(
                centred_i * centred_j,
                axis=reduction_axes,
            )
            Vik = K.mean(
                centred_i * centred_k,
                axis=reduction_axes,
            )
            Vjk = K.mean(
                centred_j * centred_k,
                axis=reduction_axes,
            )
        elif self.center:
            Vrr = None
            Vii = None
            Vjj = None
            Vkk = None
            Vri = None
            Vrj = None
            Vrk = None
            Vij = None
            Vik = None
            Vjk = None
        else:
            raise ValueError('Error. Both scale and center in batchnorm are set to False.')

        input_bn = QuaternionBN(
            input_centred, 
            Vrr, Vri, Vrj, Vrk, Vii, 
            Vij, Vik, Vjj, Vjk, Vkk,
            self.beta, 
            self.gamma_rr, self.gamma_ri, 
            self.gamma_rj, self.gamma_rk, 
            self.gamma_ii, self.gamma_ij, 
            self.gamma_ik, self.gamma_jj, 
            self.gamma_jk, self.gamma_kk, 
            self.scale, self.center,
            axis=self.axis
        )
        if training in {0, False}:
            return input_bn
        else:
            update_list = []
            if self.center:
                update_list.append(K.moving_average_update(self.moving_mean, mu, self.momentum))
            if self.scale:
                update_list.append(K.moving_average_update(self.moving_Vrr, Vrr, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vii, Vii, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vjj, Vjj, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vkk, Vkk, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vri, Vri, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vrj, Vrj, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vrk, Vrk, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vij, Vij, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vik, Vik, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vjk, Vjk, self.momentum))
            self.add_update(update_list, inputs)

            def normalize_inference():
                if self.center:
                    inference_centred = inputs - K.reshape(self.moving_mean, broadcast_mu_shape)
                else:
                    inference_centred = inputs
                return QuaternionBN(
                    inference_centred, 
                    self.moving_Vrr, self.moving_Vri, 
                    self.moving_Vrj, self.moving_Vrk,
                    self.moving_Vii, self.moving_Vij,
                    self.moving_Vik, self.moving_Vjj,
                    self.moving_Vjk, self.moving_Vkk,
                    self.beta, 
                    self.gamma_rr, self.gamma_ri, 
                    self.gamma_rj, self.gamma_rk, 
                    self.gamma_ii, self.gamma_ij, 
                    self.gamma_ik, self.gamma_jj, 
                    self.gamma_jk, self.gamma_kk, 
                    self.scale, self.center, axis=self.axis
                )

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(input_bn,
                                normalize_inference,
                                training=training)
Example No. 26
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalize_inference():
            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(self.moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(self.moving_variance,
                                                      broadcast_shape)
                if self.center:
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                else:
                    broadcast_beta = None
                if self.scale:
                    broadcast_gamma = K.reshape(self.gamma,
                                                broadcast_shape)
                else:
                    broadcast_gamma = None
                return K.batch_normalization(
                    inputs,
                    broadcast_moving_mean,
                    broadcast_moving_variance,
                    broadcast_beta,
                    broadcast_gamma,
                    epsilon=self.epsilon)
            else:
                return K.batch_normalization(
                    inputs,
                    self.moving_mean,
                    self.moving_variance,
                    self.beta,
                    self.gamma,
                    epsilon=self.epsilon)

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()

        # If the learning is either dynamic, or set to training:
        normed_training, mean, variance = K.normalize_batch_in_training(
            inputs, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        self.add_update([K.moving_average_update(self.moving_mean,
                                                 mean,
                                                 self.momentum),
                         K.moving_average_update(self.moving_variance,
                                                 variance,
                                                 self.momentum)],
                        inputs)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)
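Example No. 26 is essentially the stock Keras BatchNormalization call(). Both branches apply the same affine transform; training uses batch statistics and nudges the moving averages toward them, while inference uses the stored moving averages. A hedged numeric restatement of the transform:

import numpy as np

x, gamma, beta, eps = 3.0, 2.0, 0.5, 1e-3
mean, variance = 1.0, 4.0
y = gamma * (x - mean) / np.sqrt(variance + eps) + beta   # ~2.4997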
Example No. 27
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalize_inference():
            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(self.moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(self.moving_variance,
                                                      broadcast_shape)
                if self.center:
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                else:
                    broadcast_beta = None
                if self.scale:
                    broadcast_gamma = K.reshape(self.gamma,
                                                broadcast_shape)
                else:
                    broadcast_gamma = None
                return tf.nn.batch_normalization(#K.batch_normalization(
                    inputs,
                    broadcast_moving_mean,
                    broadcast_moving_variance,
                    broadcast_beta,
                    broadcast_gamma,
                    #axis=self.axis,
                    self.epsilon)#epsilon=self.epsilon)
            else:
                return tf.nn.batch_normalization(#K.batch_normalization(
                    inputs,
                    self.moving_mean,
                    self.moving_variance,
                    self.beta,
                    self.gamma,
                    #axis=self.axis,
                    self.epsilon)#epsilon=self.epsilon)

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()

        # If the learning is either dynamic, or set to training:
        normed_training, mean, variance = _regular_normalize_batch_in_training(#K.normalize_batch_in_training(
            inputs, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        if K.backend() != 'cntk':
            sample_size = K.prod([K.shape(inputs)[axis]
                                  for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        self.add_update([K.moving_average_update(self.moving_mean,
                                                 mean,
                                                 self.momentum),
                         K.moving_average_update(self.moving_variance,
                                                 variance,
                                                 self.momentum)],
                        inputs)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)
Example No. 28
    def call(self, x, mask=None):
        if self.mode == 0 or self.mode == 2:
            assert self.built, 'Layer must be built before being called'
            input_shape = K.int_shape(x)

            reduction_axes = list(range(len(input_shape)))
            del reduction_axes[self.axis]
            broadcast_shape = [1] * len(input_shape)
            broadcast_shape[self.axis] = input_shape[self.axis]

            mean_batch, var_batch = K.moments(x,
                                              reduction_axes,
                                              shift=None,
                                              keep_dims=False)
            std_batch = (K.sqrt(var_batch + self.epsilon))

            r_max_value = K.get_value(self.r_max)
            r = std_batch / (K.sqrt(self.running_std + self.epsilon))
            r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

            d_max_value = K.get_value(self.d_max)
            d = (mean_batch - self.running_mean) / K.sqrt(self.running_std +
                                                          self.epsilon)
            d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

            if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                x_normed_batch = (x - mean_batch) / std_batch
                x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
            else:
                # need broadcasting
                broadcast_mean = K.reshape(mean_batch, broadcast_shape)
                broadcast_std = K.reshape(std_batch, broadcast_shape)
                broadcast_r = K.reshape(r, broadcast_shape)
                broadcast_d = K.reshape(d, broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

                x_normed_batch = (x - broadcast_mean) / broadcast_std
                x_normed = (x_normed_batch * broadcast_r +
                            broadcast_d) * broadcast_gamma + broadcast_beta

            # explicit update to moving mean and standard deviation
            self.add_update([
                K.moving_average_update(self.running_mean, mean_batch,
                                        self.momentum),
                K.moving_average_update(self.running_std, std_batch**2,
                                        self.momentum)
            ], x)

            # update r_max and d_max
            t_val = K.get_value(self.t)
            r_val = self.r_max_value / (
                1 + (self.r_max_value - 1) * np.exp(-t_val))
            d_val = self.d_max_value / (1 + (
                (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
            t_val += float(self.t_delta)

            self.add_update([
                K.update(self.r_max, r_val),
                K.update(self.d_max, d_val),
                K.update(self.t, t_val)
            ], x)

            if self.mode == 0:
                if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                    x_normed_running = K.batch_normalization(
                        x,
                        self.running_mean,
                        self.running_std,
                        self.beta,
                        self.gamma,
                        epsilon=self.epsilon)
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean,
                                                       broadcast_shape)
                    broadcast_running_std = K.reshape(self.running_std,
                                                      broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        x,
                        broadcast_running_mean,
                        broadcast_running_std,
                        broadcast_beta,
                        broadcast_gamma,
                        epsilon=self.epsilon)

                # pick the normalized form of x corresponding to the training phase
                # for batch renormalization, inference time remains same as batchnorm
                x_normed = K.in_train_phase(x_normed, x_normed_running)

        elif self.mode == 1:
            # sample-wise normalization
            m = K.mean(x, axis=self.axis, keepdims=True)
            std = K.sqrt(
                K.var(x, axis=self.axis, keepdims=True) + self.epsilon)
            x_normed_batch = (x - m) / (std + self.epsilon)

            r_max_value = K.get_value(self.r_max)
            r = std / (self.running_std + self.epsilon)
            r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

            d_max_value = K.get_value(self.d_max)
            d = (m - self.running_mean) / (self.running_std + self.epsilon)
            d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

            x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta

            # update r_max and d_max
            t_val = K.get_value(self.t)
            r_val = self.r_max_value / (
                1 + (self.r_max_value - 1) * np.exp(-t_val))
            d_val = self.d_max_value / (1 + (
                (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
            t_val += float(self.t_delta)

            self.add_update([
                K.update(self.r_max, r_val),
                K.update(self.d_max, d_val),
                K.update(self.t, t_val)
            ], x)

        return x_normed
Example No. 29
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        ndim = len(input_shape)
        reduction_axes = list(range(ndim))
        del reduction_axes[self.axis]
        input_dim = input_shape[self.axis] // 2
        mu = K.mean(inputs, axis=reduction_axes)
        broadcast_mu_shape = [1] * len(input_shape)
        broadcast_mu_shape[self.axis] = input_shape[self.axis]
        broadcast_mu = K.reshape(mu, broadcast_mu_shape)
        if self.center:
            input_centred = inputs - broadcast_mu
        else:
            input_centred = inputs
        centred_squared = input_centred**2
        if (self.axis == 1 and ndim != 3) or ndim == 2:
            centred_squared_real = centred_squared[:, :input_dim]
            centred_squared_imag = centred_squared[:, input_dim:]
            centred_real = input_centred[:, :input_dim]
            centred_imag = input_centred[:, input_dim:]
        elif ndim == 3:
            centred_squared_real = centred_squared[:, :, :input_dim]
            centred_squared_imag = centred_squared[:, :, input_dim:]
            centred_real = input_centred[:, :, :input_dim]
            centred_imag = input_centred[:, :, input_dim:]
        elif self.axis == -1 and ndim == 4:
            centred_squared_real = centred_squared[:, :, :, :input_dim]
            centred_squared_imag = centred_squared[:, :, :, input_dim:]
            centred_real = input_centred[:, :, :, :input_dim]
            centred_imag = input_centred[:, :, :, input_dim:]
        elif self.axis == -1 and ndim == 5:
            centred_squared_real = centred_squared[:, :, :, :, :input_dim]
            centred_squared_imag = centred_squared[:, :, :, :, input_dim:]
            centred_real = input_centred[:, :, :, :, :input_dim]
            centred_imag = input_centred[:, :, :, :, input_dim:]
        else:
            raise ValueError(
                'Incorrect Batchnorm combination of axis and dimensions. axis should be either 1 or -1. '
                'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.')
        if self.scale:
            Vrr = K.mean(centred_squared_real,
                         axis=reduction_axes) + self.epsilon
            Vii = K.mean(centred_squared_imag,
                         axis=reduction_axes) + self.epsilon
            # Vri contains the real and imaginary covariance for each feature map.
            Vri = K.mean(
                centred_real * centred_imag,
                axis=reduction_axes,
            )
        elif self.center:
            Vrr = None
            Vii = None
            Vri = None
        else:
            raise ValueError(
                'Error. Both scale and center in batchnorm are set to False.')

        input_bn = ComplexBN(input_centred,
                             Vrr,
                             Vii,
                             Vri,
                             self.beta,
                             self.gamma_rr,
                             self.gamma_ri,
                             self.gamma_ii,
                             self.scale,
                             self.center,
                             axis=self.axis)
        if training in {0, False}:
            return input_bn
        else:
            update_list = []
            if self.center:
                update_list.append(
                    K.moving_average_update(self.moving_mean, mu,
                                            self.momentum))
            if self.scale:
                update_list.append(
                    K.moving_average_update(self.moving_Vrr, Vrr,
                                            self.momentum))
                update_list.append(
                    K.moving_average_update(self.moving_Vii, Vii,
                                            self.momentum))
                update_list.append(
                    K.moving_average_update(self.moving_Vri, Vri,
                                            self.momentum))
            self.add_update(update_list, inputs)

            def normalize_inference():
                if self.center:
                    inference_centred = inputs - K.reshape(
                        self.moving_mean, broadcast_mu_shape)
                else:
                    inference_centred = inputs
                return ComplexBN(inference_centred,
                                 self.moving_Vrr,
                                 self.moving_Vii,
                                 self.moving_Vri,
                                 self.beta,
                                 self.gamma_rr,
                                 self.gamma_ri,
                                 self.gamma_ii,
                                 self.scale,
                                 self.center,
                                 axis=self.axis)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(input_bn,
                                normalize_inference,
                                training=training)
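Esempio n. 29 delegates the actual normalization to a ComplexBN helper, which whitens each (real, imaginary) pair with the inverse square root of the 2x2 covariance matrix V = [[Vrr, Vri], [Vri, Vii]]. A minimal NumPy sketch of that whitening step, using the closed form for the inverse square root of a 2x2 symmetric positive-definite matrix (the function name and signature below are illustrative, not taken from the snippet):

import numpy as np

def complex_whiten(x_real, x_imag, Vrr, Vii, Vri):
    # Closed-form W = V ** -0.5 for V = [[Vrr, Vri], [Vri, Vii]].
    tau = Vrr + Vii                     # trace of V
    delta = Vrr * Vii - Vri ** 2        # determinant of V
    s = np.sqrt(delta)
    t = np.sqrt(tau + 2.0 * s)
    inv_st = 1.0 / (s * t)
    Wrr = (Vii + s) * inv_st
    Wii = (Vrr + s) * inv_st
    Wri = -Vri * inv_st
    # Apply W to every (real, imag) pair.
    out_real = Wrr * x_real + Wri * x_imag
    out_imag = Wri * x_real + Wii * x_imag
    return out_real, out_imag

After this whitening, the layer applies the learned 2x2 scale (gamma_rr, gamma_ri, gamma_ii) and the shift beta, which is what the gamma/beta arguments passed to ComplexBN above are for.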
Esempio n. 30
0
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalize_inference():
            def apply_mode_normalization_inference(moving_mean,
                                                   moving_variance, beta,
                                                   gamma):
                inputs_mul_gates_ = self.apply_gates(inputs, input_shape,
                                                     reduction_axes[1:])
                outputs = []
                for k_ in range(self.k):
                    outputs.append(
                        K.batch_normalization(inputs_mul_gates_[:, k_],
                                              moving_mean[k_],
                                              moving_variance[k_],
                                              beta / self.k,
                                              gamma,
                                              axis=self.axis,
                                              epsilon=self.epsilon))
                return K.sum(K.stack(outputs, axis=0), axis=0)

            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(self.moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(self.moving_variance,
                                                      broadcast_shape)
                if self.center:
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                else:
                    broadcast_beta = None
                if self.scale:
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                else:
                    broadcast_gamma = None
                return apply_mode_normalization_inference(
                    broadcast_moving_mean, broadcast_moving_variance,
                    broadcast_beta, broadcast_gamma)
            else:
                return apply_mode_normalization_inference(
                    self.moving_mean, self.moving_variance, self.beta,
                    self.gamma)

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()

        inputs_mul_gates = self.apply_gates(inputs, input_shape,
                                            reduction_axes[1:])

        # training.
        mean_list, variance_list, normed_training_list = [], [], []
        norm_func = K.normalize_batch_in_training
        for k in range(self.k):
            normed_training, mean, variance = norm_func(inputs_mul_gates[:, k],
                                                        self.gamma,
                                                        self.beta / self.k,
                                                        reduction_axes,
                                                        epsilon=self.epsilon)
            normed_training_list.append(normed_training)
            mean_list.append(mean)
            variance_list.append(variance)

        mean = K.stack(mean_list, axis=0)
        variance = K.stack(variance_list, axis=0)
        normed_training = K.sum(K.stack(normed_training_list, axis=0), axis=0)

        if K.backend() != 'cntk':
            sample_size = K.prod(
                [K.shape(inputs)[axis] for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        self.add_update([
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            K.moving_average_update(self.moving_variance, variance,
                                    self.momentum)
        ], inputs)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)
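Esempio n. 30 relies on an apply_gates helper that is not shown. A rough sketch of the soft assignment it is assumed to perform for mode normalization: each sample is weighted across k modes by a learned gate followed by a softmax, and the input is replicated along a new mode axis (the gate parametrisation below, gate_w and gate_b, is an assumption; the real helper may first pool spatial axes for higher-rank inputs):

import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

# Sketch for 2D inputs of shape (batch, features): returns (batch, k, features),
# where slice [:, k_] is the input weighted by its affinity to mode k_.
def apply_gates_sketch(x, gate_w, gate_b):
    gates = softmax(x @ gate_w + gate_b)        # (batch, k)
    return gates[:, :, None] * x[:, None, :]    # (batch, k, features)

Each of the k weighted copies is then normalized against its own statistics (moving_mean[k_] and moving_variance[k_] at inference, per-mode batch statistics during training) and the k results are summed back together, which is what both branches of the call above do.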
    def call(self, inputs, training=None):
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(inputs)

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        mean_batch, var_batch = _moments(inputs,
                                         reduction_axes,
                                         shift=None,
                                         keep_dims=False)
        std_batch = (K.sqrt(var_batch + self.epsilon))

        r = std_batch / (K.sqrt(self.running_variance + self.epsilon))
        r = K.stop_gradient(K.clip(r, 1 / self.r_max, self.r_max))

        d = (mean_batch - self.running_mean) / K.sqrt(self.running_variance +
                                                      self.epsilon)
        d = K.stop_gradient(K.clip(d, -self.d_max, self.d_max))

        if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
            x_normed_batch = (inputs - mean_batch) / std_batch
            x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
        else:
            # need broadcasting
            broadcast_mean = K.reshape(mean_batch, broadcast_shape)
            broadcast_std = K.reshape(std_batch, broadcast_shape)
            broadcast_r = K.reshape(r, broadcast_shape)
            broadcast_d = K.reshape(d, broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

            x_normed_batch = (inputs - broadcast_mean) / broadcast_std
            x_normed = (x_normed_batch * broadcast_r +
                        broadcast_d) * broadcast_gamma + broadcast_beta

        # explicit updates to the moving mean and moving variance (std_batch ** 2)
        mean_update = K.moving_average_update(self.running_mean, mean_batch,
                                              self.momentum)
        variance_update = K.moving_average_update(self.running_variance,
                                                  std_batch**2, self.momentum)
        self.add_update([mean_update, variance_update], inputs)

        # update r_max and d_max
        r_val = self.r_max_value / (1 +
                                    (self.r_max_value - 1) * K.exp(-self.t))
        d_val = (self.d_max_value /
                 (1 + ((self.d_max_value / 1e-3) - 1) * K.exp(-(2 * self.t))))

        self.add_update([
            K.update(self.r_max, r_val),
            K.update(self.d_max, d_val),
            K.update_add(self.t, self.t_delta_tensor)
        ], inputs)

        if training in {0, False}:
            return x_normed
        else:

            def normalize_inference():
                if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
                    x_normed_running = K.batch_normalization(
                        inputs,
                        self.running_mean,
                        self.running_variance,
                        self.beta,
                        self.gamma,
                        epsilon=self.epsilon)

                    return x_normed_running
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean,
                                                       broadcast_shape)
                    broadcast_running_variance = K.reshape(
                        self.running_variance, broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        inputs,
                        broadcast_running_mean,
                        broadcast_running_variance,
                        broadcast_beta,
                        broadcast_gamma,
                        epsilon=self.epsilon)

                    return x_normed_running

            # pick the normalized form of inputs corresponding to the training phase
            # for batch renormalization, inference time remains same as batchnorm
            x_normed = K.in_train_phase(x_normed,
                                        normalize_inference,
                                        training=training)

            return x_normed
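For reference, the batch renormalization correction in this last call can be stated compactly: r rescales and d shifts the batch-normalized activations toward the running statistics, and both are clipped and excluded from the gradient via K.stop_gradient. In the snippet, r_max ramps from 1 toward r_max_value and d_max from about 1e-3 toward d_max_value as the step counter t grows. A minimal NumPy sketch of the training-time transform (the concrete clip values below are illustrative defaults, not taken from the snippet):

import numpy as np

def batch_renorm_train(x, running_mean, running_var, gamma, beta,
                       r_max=3.0, d_max=5.0, eps=1e-3):
    # Batch statistics over the batch axis.
    mean_b = x.mean(axis=0)
    std_b = np.sqrt(x.var(axis=0) + eps)
    running_std = np.sqrt(running_var + eps)
    # Renormalization corrections (treated as constants w.r.t. the gradient).
    r = np.clip(std_b / running_std, 1.0 / r_max, r_max)
    d = np.clip((mean_b - running_mean) / running_std, -d_max, d_max)
    # Normalize with batch statistics, then correct toward the running ones.
    x_hat = (x - mean_b) / std_b
    return (x_hat * r + d) * gamma + beta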