def call(self, x, mask=None):
        """Normalize `x` with the stored running statistics.

        Applies batch normalization using `self.running_mean` /
        `self.running_std`, then the learned scale (`self.gamma`) and shift
        (`self.beta`).

        # Arguments
            x: input tensor.
            mask: unused; kept for Keras layer-API compatibility.

        # Returns
            The normalized tensor, same shape as `x`.
        """
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(x)

        # Normalize over every axis except the feature axis.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # FIX: compare list to list. In Python 3 `range(...)` is a range
        # object that never compares equal to a list, so the original
        # `== range(K.ndim(x))[:-1]` was always False and the broadcasting
        # branch ran unconditionally.
        if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
            # Feature axis is last: the backend broadcasts natively.
            x_normed = K.batch_normalization(x,
                                             self.running_mean,
                                             self.running_std,
                                             self.beta,
                                             self.gamma,
                                             epsilon=self.epsilon)
        else:
            # Need broadcasting: reshape every per-feature parameter so it
            # aligns with the feature axis of `x`.
            broadcast_running_mean = K.reshape(self.running_mean,
                                               broadcast_shape)
            broadcast_running_std = K.reshape(self.running_std,
                                              broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            x_normed = K.batch_normalization(x,
                                             broadcast_running_mean,
                                             broadcast_running_std,
                                             broadcast_beta,
                                             broadcast_gamma,
                                             epsilon=self.epsilon)

        return x_normed
Example 2
0
 def normalize_inference():
     """Normalize `inputs` with the frozen moving statistics (inference path)."""
     if not needs_broadcasting:
         # Parameters already line up with the feature axis.
         return K.batch_normalization(inputs,
                                      self.moving_mean,
                                      self.moving_variance,
                                      self.beta,
                                      self.gamma,
                                      epsilon=self.epsilon)

     # Explicitly broadcast each per-feature parameter to the input's rank;
     # beta/gamma are optional and stay None when centering/scaling is off.
     mean_b = K.reshape(self.moving_mean, broadcast_shape)
     variance_b = K.reshape(self.moving_variance, broadcast_shape)
     beta_b = K.reshape(self.beta, broadcast_shape) if self.center else None
     gamma_b = K.reshape(self.gamma, broadcast_shape) if self.scale else None
     return K.batch_normalization(inputs,
                                  mean_b,
                                  variance_b,
                                  beta_b,
                                  gamma_b,
                                  epsilon=self.epsilon)
Example 3
0
            def normalize_inference():
                """Inference path: normalize with the running statistics.

                # Returns
                    The normalized tensor produced from `inputs` using
                    `self.running_mean` / `self.running_variance`.
                """
                # FIX: compare list to list. In Python 3 `range(...)` never
                # equals a list, so the original `== range(...)[:-1]` was
                # always False and the broadcasting branch always ran.
                if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
                    # Feature axis is last: backend broadcasts natively.
                    x_normed_running = K.batch_normalization(
                        inputs,
                        self.running_mean,
                        self.running_variance,
                        self.beta,
                        self.gamma,
                        epsilon=self.epsilon)

                    return x_normed_running
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean,
                                                       broadcast_shape)
                    broadcast_running_std = K.reshape(self.running_variance,
                                                      broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        inputs,
                        broadcast_running_mean,
                        broadcast_running_std,
                        broadcast_beta,
                        broadcast_gamma,
                        epsilon=self.epsilon)

                    return x_normed_running
Example 4
0
        def normalize_inference():
            """Inference normalization using the repeated moving statistics."""
            if not needs_broadcasting:
                # Parameters already match the input layout.
                return K.batch_normalization(inputs,
                                             repeated_moving_mean,
                                             repeated_moving_variance,
                                             repeated_beta,
                                             repeated_gamma,
                                             epsilon=self.epsilon)

            # Broadcast each repeated parameter up to the input's rank.
            mean_b = K.reshape(repeated_moving_mean, broadcast_shape)
            var_b = K.reshape(repeated_moving_variance, broadcast_shape)
            beta_b = K.reshape(repeated_beta, broadcast_shape)
            gamma_b = K.reshape(repeated_gamma, broadcast_shape)
            return K.batch_normalization(inputs,
                                         mean_b,
                                         var_b,
                                         beta_b,
                                         gamma_b,
                                         epsilon=self.epsilon)
Example 5
0
 def normalize_inference():
     """Run inference-time batch norm with the frozen moving statistics."""
     if not needs_broadcasting:
         # Parameters already broadcast correctly along `self.axis`.
         return K.batch_normalization(inputs,
                                      self.moving_mean,
                                      self.moving_variance,
                                      self.beta,
                                      self.gamma,
                                      axis=self.axis,
                                      epsilon=self.epsilon)

     # Reshape the parameters so they broadcast against the input tensor;
     # beta/gamma stay None when centering/scaling is disabled.
     mean_b = K.reshape(self.moving_mean, broadcast_shape)
     var_b = K.reshape(self.moving_variance, broadcast_shape)
     beta_b = K.reshape(self.beta, broadcast_shape) if self.center else None
     gamma_b = K.reshape(self.gamma, broadcast_shape) if self.scale else None
     return K.batch_normalization(inputs,
                                  mean_b,
                                  var_b,
                                  beta_b,
                                  gamma_b,
                                  axis=self.axis,
                                  epsilon=self.epsilon)
    def call(self, inputs, training=None):
        """Normalize with the frozen moving statistics and, when
        `self.update_stats` is set, schedule a moving-average update of
        those statistics from the current minibatch.

        # Arguments
            inputs: a single input tensor (lists are rejected).
            training: unused; kept for Keras layer-API compatibility.

        # Returns
            The normalized and rescaled tensor.
        """
        assert not isinstance(inputs, list)

        # Normalization and rescaling always use the stored moving stats.
        normalized = K.batch_normalization(inputs,
                                           self.moving_mean,
                                           self.moving_variance,
                                           self.beta,
                                           self.gamma,
                                           epsilon=self.epsilon)

        if self.update_stats:
            # Minibatch statistics over all axes but the last (feature) axis.
            mean, var = self._moments(
                inputs, axes=range(len(K.int_shape(inputs)) - 1))
            updates = [
                K.moving_average_update(self.moving_mean, mean, self.momentum),
                K.moving_average_update(self.moving_variance, var,
                                        self.momentum),
            ]
            self.add_update(updates, inputs)

        return normalized
        def normalize_inference():
            """Inference-time batch normalization with optional quantization.

            `self.quant_mode` selects how much of the computation is quantized:
              * 'intrinsic': quantized inputs/params through the quantized
                BN core (`QuantizedBatchNormalizationCore`).
              * 'hybrid': quantized inputs/params through the float BN kernel,
                then the output is quantized.
              * 'extrinsic': float BN on the raw inputs, output quantized.
              * None: plain float batch normalization.
            Any other value falls through and implicitly returns None.
            """
            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(self.moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(self.moving_variance,
                                                      broadcast_shape)
                if self.center:
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                else:
                    broadcast_beta = None
                if self.scale:
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                else:
                    broadcast_gamma = None

                # Quantize parameters and inputs for the quantized paths
                # (merged the two identical quant_mode checks — nothing
                # executed between them).
                if self.quant_mode in ['hybrid', 'intrinsic']:
                    broadcast_moving_mean = quantizer_weight.quantize(
                        broadcast_moving_mean)
                    broadcast_moving_variance = quantizer_weight.quantize(
                        broadcast_moving_variance)
                    if self.center:
                        broadcast_beta = quantizer_weight.quantize(
                            broadcast_beta)
                    if self.scale:
                        broadcast_gamma = quantizer_weight.quantize(
                            broadcast_gamma)
                    quantized_inputs = quantizer_input.quantize(inputs)

                if self.quant_mode == 'intrinsic':
                    return QuantizedBatchNormalizationCore(
                        quantized_inputs, broadcast_moving_mean,
                        broadcast_moving_variance, broadcast_beta,
                        broadcast_gamma, self.epsilon, quantizer_output)
                elif self.quant_mode == 'hybrid':
                    output = K.batch_normalization(quantized_inputs,
                                                   broadcast_moving_mean,
                                                   broadcast_moving_variance,
                                                   broadcast_beta,
                                                   broadcast_gamma,
                                                   axis=self.axis,
                                                   epsilon=self.epsilon)
                    return quantizer_output.quantize(output)
                elif self.quant_mode == 'extrinsic':
                    output = K.batch_normalization(inputs,
                                                   broadcast_moving_mean,
                                                   broadcast_moving_variance,
                                                   broadcast_beta,
                                                   broadcast_gamma,
                                                   axis=self.axis,
                                                   epsilon=self.epsilon)
                    return quantizer_output.quantize(output)
                elif self.quant_mode is None:
                    return K.batch_normalization(inputs,
                                                 broadcast_moving_mean,
                                                 broadcast_moving_variance,
                                                 broadcast_beta,
                                                 broadcast_gamma,
                                                 axis=self.axis,
                                                 epsilon=self.epsilon)

            else:
                if self.quant_mode in ['hybrid', 'intrinsic']:
                    moving_mean = quantizer_weight.quantize(self.moving_mean)
                    moving_variance = quantizer_weight.quantize(
                        self.moving_variance)
                    if self.center:
                        beta = quantizer_weight.quantize(self.beta)
                    else:
                        beta = self.beta
                    if self.scale:
                        gamma = quantizer_weight.quantize(self.gamma)
                    else:
                        gamma = self.gamma
                    quantized_inputs = quantizer_input.quantize(inputs)

                if self.quant_mode == 'intrinsic':
                    return QuantizedBatchNormalizationCore(
                        quantized_inputs, moving_mean, moving_variance, beta,
                        gamma, self.epsilon, quantizer_output)
                elif self.quant_mode == 'hybrid':
                    output = K.batch_normalization(quantized_inputs,
                                                   moving_mean,
                                                   moving_variance,
                                                   beta,
                                                   gamma,
                                                   axis=self.axis,
                                                   epsilon=self.epsilon)
                    return quantizer_output.quantize(output)
                elif self.quant_mode == 'extrinsic':
                    output = K.batch_normalization(inputs,
                                                   self.moving_mean,
                                                   self.moving_variance,
                                                   self.beta,
                                                   self.gamma,
                                                   axis=self.axis,
                                                   epsilon=self.epsilon)
                    return quantizer_output.quantize(output)
                # FIX: `is None` (identity test, PEP 8) instead of `== None`;
                # also matches the equivalent branch above.
                elif self.quant_mode is None:
                    return K.batch_normalization(inputs,
                                                 self.moving_mean,
                                                 self.moving_variance,
                                                 self.beta,
                                                 self.gamma,
                                                 axis=self.axis,
                                                 epsilon=self.epsilon)
Example 8
0
    def call(self, x, mask=None):
        """Batch renormalization forward pass.

        Modes:
            0/2: feature-wise normalization with batch statistics, corrected
                 by the renorm factors r and d; mode 0 additionally switches
                 to the running statistics for the inference phase.
            1:   sample-wise normalization along `self.axis`.

        # Arguments
            x: input tensor.
            mask: unused; kept for Keras layer-API compatibility.

        # Returns
            The (re)normalized tensor, same shape as `x`.

        NOTE(review): a mode outside {0, 1, 2} would leave `x_normed` unbound
        and raise at the final return — presumably modes are validated
        upstream; confirm in the constructor.
        """
        if self.mode == 0 or self.mode == 2:
            assert self.built, 'Layer must be built before being called'
            input_shape = K.int_shape(x)

            # Normalize over every axis except the feature axis.
            reduction_axes = list(range(len(input_shape)))
            del reduction_axes[self.axis]
            broadcast_shape = [1] * len(input_shape)
            broadcast_shape[self.axis] = input_shape[self.axis]

            # mean_batch, var_batch = K.moments(x, reduction_axes, shift=None, keep_dims=False)
            normed, mean_batch, var_batch = K.normalize_batch_in_training(
                x, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)

            std_batch = (K.sqrt(var_batch + self.epsilon))

            # Renorm correction factors; gradients are stopped so r and d act
            # as constants: r rescales the batch-normalized output, d shifts it.
            r_max_value = K.get_value(self.r_max)
            r = std_batch / (K.sqrt(self.running_std + self.epsilon))
            r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

            d_max_value = K.get_value(self.d_max)
            d = (mean_batch - self.running_mean) / K.sqrt(self.running_std +
                                                          self.epsilon)
            d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

            # FIX: compare list to list. In Python 3 `range(...)` never equals
            # a list, so the original `== range(K.ndim(x))[:-1]` was always
            # False and the broadcasting branch ran unconditionally.
            if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                x_normed_batch = (x - mean_batch) / std_batch
                x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
            else:
                # need broadcasting
                broadcast_mean = K.reshape(mean_batch, broadcast_shape)
                broadcast_std = K.reshape(std_batch, broadcast_shape)
                broadcast_r = K.reshape(r, broadcast_shape)
                broadcast_d = K.reshape(d, broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

                x_normed_batch = (x - broadcast_mean) / broadcast_std
                x_normed = (x_normed_batch * broadcast_r +
                            broadcast_d) * broadcast_gamma + broadcast_beta

            # explicit update to moving mean and standard deviation
            self.add_update([
                K.moving_average_update(self.running_mean, mean_batch,
                                        self.momentum),
                K.moving_average_update(self.running_std, std_batch**2,
                                        self.momentum)
            ], x)

            # Anneal r_max and d_max toward their configured maxima as the
            # step counter t grows.
            t_val = K.get_value(self.t)
            r_val = self.r_max_value / (
                1 + (self.r_max_value - 1) * np.exp(-t_val))
            d_val = self.d_max_value / (1 + (
                (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
            t_val += float(self.t_delta)

            self.add_update([
                K.update(self.r_max, r_val),
                K.update(self.d_max, d_val),
                K.update(self.t, t_val)
            ], x)

            if self.mode == 0:
                # FIX: same list-vs-range Python 3 comparison bug as above.
                if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                    x_normed_running = K.batch_normalization(
                        x,
                        self.running_mean,
                        self.running_std,
                        self.beta,
                        self.gamma,
                        epsilon=self.epsilon)
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean,
                                                       broadcast_shape)
                    broadcast_running_std = K.reshape(self.running_std,
                                                      broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        x,
                        broadcast_running_mean,
                        broadcast_running_std,
                        broadcast_beta,
                        broadcast_gamma,
                        epsilon=self.epsilon)

                # pick the normalized form of x corresponding to the training phase
                # for batch renormalization, inference time remains same as batchnorm
                x_normed = K.in_train_phase(x_normed, x_normed_running)

        elif self.mode == 1:
            # sample-wise normalization
            m = K.mean(x, axis=self.axis, keepdims=True)
            std = K.sqrt(
                K.var(x, axis=self.axis, keepdims=True) + self.epsilon)
            x_normed_batch = (x - m) / (std + self.epsilon)

            r_max_value = K.get_value(self.r_max)
            r = std / (self.running_std + self.epsilon)
            r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

            d_max_value = K.get_value(self.d_max)
            d = (m - self.running_mean) / (self.running_std + self.epsilon)
            d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

            x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta

            # Anneal r_max and d_max toward their configured maxima.
            t_val = K.get_value(self.t)
            r_val = self.r_max_value / (
                1 + (self.r_max_value - 1) * np.exp(-t_val))
            d_val = self.d_max_value / (1 + (
                (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
            t_val += float(self.t_delta)

            self.add_update([
                K.update(self.r_max, r_val),
                K.update(self.d_max, d_val),
                K.update(self.t, t_val)
            ], x)

        return x_normed