Example No. 1
        def training_phase():
            mean_batch = K.mean(mean_instance, axis=0, keepdims=True)
            variance_batch = K.mean(temp, axis=0,
                                    keepdims=True) - K.square(mean_batch)

            mean_batch_reshaped = K.flatten(mean_batch)
            variance_batch_reshaped = K.flatten(variance_batch)

            if K.backend() != 'cntk':
                sample_size = K.prod(
                    [K.shape(inputs)[axis] for axis in reduction_axes])
                sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

                # sample variance - unbiased estimator of population variance
                variance_batch_reshaped *= sample_size / (sample_size -
                                                          (1.0 + self.epsilon))

            self.add_update([
                K.moving_average_update(self.moving_mean, mean_batch_reshaped,
                                        self.momentum),
                K.moving_average_update(self.moving_variance,
                                        variance_batch_reshaped, self.momentum)
            ])

            return normalize_func(mean_batch, variance_batch)
Example No. 2
    def call(self, inputs, training=None):
        x = inputs
        assert not isinstance(x, list)

        # Compute the minibatch statistics
        mean, var = self._moments(x)
        sigma = K.sqrt(var + self.epsilon)

        # When not in the training phase, rmax and dmax are set very large so
        # that the (effectively unclipped) corrections make the moving
        # averages do the normalization
        rmax = K.in_train_phase(self.rmax, K.constant(1e5), training)
        dmax = K.in_train_phase(self.dmax, K.constant(1e5), training)

        # Compute the corrections based on rmax, dmax
        r = K.stop_gradient(
            self._clip(sigma / self.moving_sigma, 1. / rmax, rmax))
        d = K.stop_gradient(
            self._clip((mean - self.moving_mean) / self.moving_sigma, -dmax,
                       dmax))

        # Actually do the normalization and the rescaling
        xnorm = ((x - mean) / sigma) * r + d
        y = self.gamma * xnorm + self.beta

        # Add the moving average updates
        self.add_update([
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            K.moving_average_update(self.moving_sigma, sigma, self.momentum)
        ], x)

        # Add the r, d updates
        rmax_prog = K.minimum(1., self.steps / self.rmax_dur)
        dmax_prog = K.minimum(1., self.steps / self.dmax_dur)
        self.add_update([
            K.update_add(self.steps, 1),
            K.update(self.rmax,
                     self.rmax_0 + rmax_prog * (self.rmax_inf - self.rmax_0)),
            K.update(self.dmax,
                     self.dmax_0 + dmax_prog * (self.dmax_inf - self.dmax_0))
        ])

        # Propagate the learning-phase flag from rmax to the output
        y._uses_learning_phase = rmax._uses_learning_phase

        return y
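
A note on the snippet above: r and d are the Batch Renormalization corrections; r rescales the batch statistics toward the moving statistics, d shifts them, and gradients are blocked through both. A minimal standalone sketch of the same computation (the moving statistics and clipping bounds below are placeholders, not the layer's actual state):

    import numpy as np
    from tensorflow.keras import backend as K

    epsilon = 1e-3
    rmax, dmax = 3.0, 5.0                      # clipping bounds (illustrative)
    x = K.constant(np.random.randn(32, 8).astype("float32"))
    moving_mean = K.zeros((8,))
    moving_sigma = K.ones((8,))

    mean = K.mean(x, axis=0)
    sigma = K.sqrt(K.var(x, axis=0) + epsilon)
    # Corrections are clipped and excluded from the gradient computation.
    r = K.stop_gradient(K.clip(sigma / moving_sigma, 1.0 / rmax, rmax))
    d = K.stop_gradient(K.clip((mean - moving_mean) / moving_sigma, -dmax, dmax))
    xnorm = ((x - mean) / sigma) * r + d       # same form as in the layer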
Example No. 3
 def inject(self):
     """
     add moving average update op
     to model.metrics_updates
     """
     self.initialize()
     for w1, w2 in zip(self.ema_weights, self.model.weights):
         op = K.moving_average_update(w1, w2, self.momentum)
         self.model.metrics_updates.append(op)
Example No. 4
 def inject(self):
     """添加更新算子到model.metrics_updates。 
     """
     self.initialize()
     for w1, w2 in zip(self.ema_weights, self.model.weights):
         op = K.moving_average_update(w1, w2, self.momentum)
         #self.model.metrics_updates.append(op)  # works in keras 2.2.4
         if not hasattr(self.model, '_other_metrics'):
             self.model._other_metrics = []
         self.model._other_metrics.append(op)
Example No. 5
        def update_branch():
            """ Update the moving average when is_ema_training is True."""

            # Set the qnoise factor to 0 to update the EMA using the unquantized input
            prev_qnoise_factor = tf.identity(self.quantizer.qnoise_factor)
            self.quantizer.update_qnoise_factor(tf.constant(0.0))

            # Update the EMA
            # act_x is the input after the activation function but before the
            # quantizer; this is achieved by using a qnoise_factor of 0.
            act_x = self.quantizer(x)
            new_min = tf.squeeze(K.min(act_x, axis=axis, keepdims=True))
            K.moving_average_update(self.ema_min, new_min, self.ema_decay)
            new_max = tf.squeeze(K.max(act_x, axis=axis, keepdims=True))
            K.moving_average_update(self.ema_max, new_max, self.ema_decay)

            # Reset the qnoise factor to the previous value
            self.quantizer.update_qnoise_factor(prev_qnoise_factor)
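
The snippet above tracks the running minimum and maximum of the unquantized activations with an EMA so they can later define the quantization range. A small standalone sketch of that running-range update (names, shapes and values below are illustrative only):

    import tensorflow as tf
    from tensorflow.keras import backend as K

    ema_decay = 0.9
    ema_min = K.variable(0.0)                  # running minimum of the activations
    ema_max = K.variable(0.0)                  # running maximum of the activations

    act_x = tf.random.uniform((32, 16), minval=-2.0, maxval=3.0)
    new_min = tf.squeeze(K.min(act_x, axis=[0, 1], keepdims=True))
    new_max = tf.squeeze(K.max(act_x, axis=[0, 1], keepdims=True))
    K.moving_average_update(ema_min, new_min, ema_decay)
    K.moving_average_update(ema_max, new_max, ema_decay)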
Example No. 6
 def set_model(self, model):
     """绑定模型,并初始化参数
     """
     super(ExponentialMovingAverage, self).set_model(model)
     self.ema_weights = [K.zeros(K.shape(w)) for w in model.weights]
     self.old_weights = K.batch_get_value(model.weights)
     K.batch_set_value(zip(self.ema_weights, self.old_weights))
     self.updates = []
     for w1, w2 in zip(self.ema_weights, model.weights):
         op = K.moving_average_update(w1, w2, self.momentum)
         self.updates.append(op)
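
Examples No. 3, 4 and 6 are pieces of the same weight-EMA callback pattern: set_model creates shadow variables initialized to the current weights, and inject (or a per-batch hook) registers the K.moving_average_update ops. A self-contained sketch of that pattern, assuming a plain tf.keras training loop; the class name and the apply/restore helpers are hypothetical, and on_train_batch_end stands in for the metrics_updates injection shown above:

    import tensorflow as tf
    from tensorflow.keras import backend as K

    class EMACallback(tf.keras.callbacks.Callback):
        def __init__(self, momentum=0.999):
            super().__init__()
            self.momentum = momentum

        def set_model(self, model):
            super().set_model(model)
            # Shadow variables initialized to the current weights.
            self.ema_weights = [K.zeros(K.int_shape(w)) for w in model.weights]
            K.batch_set_value(list(zip(self.ema_weights,
                                       K.batch_get_value(model.weights))))

        def on_train_batch_end(self, batch, logs=None):
            # ema <- momentum * ema + (1 - momentum) * weight
            for ema_w, w in zip(self.ema_weights, self.model.weights):
                K.moving_average_update(ema_w, w, self.momentum)

        def apply_ema_weights(self):
            # Swap in the averaged weights (e.g. before evaluation) ...
            self.backup = K.batch_get_value(self.model.weights)
            K.batch_set_value(list(zip(self.model.weights,
                                       K.batch_get_value(self.ema_weights))))

        def restore_weights(self):
            # ... and restore the live weights afterwards.
            K.batch_set_value(list(zip(self.model.weights, self.backup)))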
Example No. 7
    def call(self, inputs, training=None):
        x = inputs
        assert not isinstance(x, list)

        # Do the normalization and the rescaling
        xnorm = K.batch_normalization(x,
                                      self.moving_mean,
                                      self.moving_variance,
                                      self.beta,
                                      self.gamma,
                                      epsilon=self.epsilon)

        # Compute and update the minibatch statistics
        if self.update_stats:
            mean, var = self._moments(x, axes=range(len(K.int_shape(x)) - 1))
            self.add_update([
                K.moving_average_update(self.moving_mean, mean, self.momentum),
                K.moving_average_update(self.moving_variance, var,
                                        self.momentum)
            ], x)

        return xnorm
Example No. 8
        def training_phase():
            # Depthwise conv with soft ReLU
            dconvs = K.depthwise_conv2d(inputs,
                                        self.depthwise_kernel,
                                        strides=self.strides,
                                        padding=self.padding,
                                        data_format='channels_last')

            #            dconvs = tf.where(dconvs<=2**(-self.L_A[1]-1), tf.zeros_like(dconvs), dconvs)
            #            factor2 = 0.9*self.max_activity
            #            dconvs = K.minimum(dconvs, 0.1*dconvs+factor2)
            factor2 = 0.9 * self.max_activity_signed
            dconvs = K.minimum(dconvs, 0.1 * dconvs + factor2)
            dconvs = K.maximum(dconvs, 0.1 * dconvs - factor2)

            # Pointwise-Conv
            convs = K.conv2d(dconvs,
                             self.kernel,
                             strides=(1, 1),
                             padding=self.padding,
                             data_format='channels_last',
                             dilation_rate=self.dilation_rate)
            convs = K.bias_add(convs, self.bias, data_format='channels_last')

            # Scaling
            scale1 = K.abs(self.max_activity_x /
                           (K.max(K.abs(convs), axis=(0, 1, 2)) + 1e-6))
            indizes = K.greater(scale1, self.max_scale)
            scale1 = self.w_scale * tf.to_float(indizes) + tf.to_float(
                ~indizes) * scale1

            scale2 = self.max_weight / (K.maximum(
                tf.abs(self.bias),
                tf.reduce_max(tf.abs(self.kernel), axis=(0, 1, 2))) + 1e-6)
            scale = K.minimum(scale1, scale2)

            self.add_update(
                K.moving_average_update(self.w_scale, scale, self.momentum),
                inputs)
            # Softclipped-linear
            outputs = convs * self.w_scale
            #            outputs = K.clip(outputs, min_value=-self.max_activity_signed, max_value=self.max_activity_signed)
            outputs = tf.where(outputs <= 2**(-self.L_A[1] - 1),
                               tf.zeros_like(outputs), outputs)
            outputs = K.minimum(outputs, 0.1 * outputs + factor2)
            return outputs
Example No. 9
        def training_phase():
            convs = K.conv2d(inputs,
                             self.kernel,
                             data_format='channels_last',
                             strides=self.strides,
                             padding=self.padding,
                             dilation_rate=self.dilation_rate)

            if self.use_bias:
                if self.data_format == 'channels_last':
                    convs = K.bias_add(convs,
                                       self.bias,
                                       data_format='channels_last')
                scale2 = self.max_weight / (K.maximum(
                    tf.abs(self.bias),
                    tf.reduce_max(tf.abs(self.kernel), axis=(0, 1, 2))) + 1e-6)
            else:
                scale2 = self.max_weight / (
                    tf.reduce_max(tf.abs(self.kernel), axis=(0, 1, 2)) + 1e-6)

            indizes = K.greater(K.max(convs, axis=(0, 1, 2)), 0.01)
            scale1 = self.w_scale * tf.cast(~indizes, tf.float32) + tf.cast(
                indizes,
                tf.float32) * K.abs(self.max_activity_x /
                                    (K.max(convs, axis=(0, 1, 2)) + 1e-6))

            scale = K.minimum(K.minimum(scale1, scale2), self.max_scale)

            self.add_update(
                K.moving_average_update(self.w_scale, scale, self.momentum))

            outputs = convs * self.w_scale
            if self.data_format == 'channels_last':
                outputs = tf.transpose(outputs, [0, 3, 1, 2])
                outputs = tf.where(outputs <= 2**(-self.L_A[1] - 1),
                                   tf.zeros_like(outputs), outputs)
                outputs = tf.transpose(outputs, [0, 2, 3, 1])
            else:
                outputs = tf.where(outputs <= 2**(-self.L_A[1] - 1),
                                   tf.zeros_like(outputs), outputs)
            outputs = K.minimum(outputs,
                                self.max_activity)  #0.1*outputs+self.factor)
            return outputs
Example No. 10
    def call(self, inputs, training=False):
        x = inputs
        training = training and self.trainable
        self.will_ema_freeze = self.will_ema_freeze and self.trainable

        # Update the step count if the optimizer step count is unknown
        self.step.assign_add(
            K.switch(
                tf.math.logical_and(self.is_estimating_step_count, training),
                tf.constant(1, tf.int64), tf.constant(0, tf.int64)))

        # Perform the quantization
        if training:
            # Calculate the qnoise, a scalar from 0 to 1 that represents the level of
            # quantization noise to use. At training start, we want no quantization,
            # so qnoise_factor = 0.0. After quantization_delay steps, we want normal
            # quantization, so qnoise_factor = 1.0.
            qnoise_factor = K.switch(
                tf.greater_equal(self.step, self.quantization_delay),
                lambda: tf.constant(1.0), lambda: tf.constant(0.0))
            qx = self.quantizer(x, qnoise_factor=qnoise_factor)

        else:  # If not training, we always want to use full quantization
            qx = self.quantizer(x, qnoise_factor=tf.constant(1.0))

        # Calculate the axis along where to find the min and max EMAs
        len_axis = len(x.shape)
        if len_axis > 1:
            if self.per_channel:
                if K.image_data_format() == "channels_last":
                    axis = list(range(len_axis - 1))
                else:
                    axis = list(range(1, len_axis))
            else:
                axis = list(range(len_axis))
        else:
            axis = [0]

        # Determine if freezing the EMA
        is_ema_training = tf.constant(training, dtype=tf.bool)
        if self.will_ema_freeze:
            is_ema_training = tf.cond(
                tf.greater(self.step, self.ema_freeze_delay),
                lambda: tf.constant(False), lambda: tf.constant(True))

        # Update the moving average
        if is_ema_training:
            new_min = tf.squeeze(K.min(qx, axis=axis, keepdims=True))
            K.moving_average_update(self.ema_min, new_min, self.ema_decay)
            new_max = tf.squeeze(K.max(qx, axis=axis, keepdims=True))
            K.moving_average_update(self.ema_max, new_max, self.ema_decay)

        # Set the integer bits for the quantizer
        integer_bits = _get_integer_bits(min_value=self.ema_min,
                                         max_value=self.ema_max,
                                         bits=self.total_bits,
                                         symmetric=self.symmetric,
                                         keep_negative=self.keep_negative,
                                         is_clipping=self.po2_rounding)
        self.quantizer.integer.assign(integer_bits)

        return qx
Example No. 11
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalize_inference():
            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(self.moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(self.moving_variance,
                                                      broadcast_shape)
                if self.center:
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                else:
                    broadcast_beta = None
                if self.scale:
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                else:
                    broadcast_gamma = None
                return tf.nn.batch_normalization(  #K.batch_normalization(
                    inputs,
                    broadcast_moving_mean,
                    broadcast_moving_variance,
                    broadcast_beta,
                    broadcast_gamma,
                    #axis=self.axis,
                    self.epsilon)  #epsilon=self.epsilon)
            else:
                return tf.nn.batch_normalization(  #K.batch_normalization(
                    inputs,
                    self.moving_mean,
                    self.moving_variance,
                    self.beta,
                    self.gamma,
                    #axis=self.axis,
                    self.epsilon)  #epsilon=self.epsilon)

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()

        # If the learning is either dynamic, or set to training:
        normed_training, mean, variance = _regular_normalize_batch_in_training(  #K.normalize_batch_in_training(
            inputs,
            self.gamma,
            self.beta,
            reduction_axes,
            epsilon=self.epsilon)

        if K.backend() != 'cntk':
            sample_size = K.prod(
                [K.shape(inputs)[axis] for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        self.add_update([
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            K.moving_average_update(self.moving_variance, variance,
                                    self.momentum)
        ], inputs)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)
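
The sample-size factor above converts the biased batch variance (divide by N) into an approximately unbiased estimate (divide by N - 1); the snippet subtracts 1.0 + epsilon rather than exactly 1.0. A quick NumPy check of the plain N / (N - 1) correction:

    import numpy as np

    x = np.random.randn(32, 4).astype("float32")
    sample_size = x.shape[0]                     # reduction over the batch axis only
    biased_var = x.var(axis=0)                   # divides by N
    unbiased_var = biased_var * sample_size / (sample_size - 1.0)
    assert np.allclose(unbiased_var, x.var(axis=0, ddof=1))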
Example No. 12
    def call(self, inputs, training=None):
        if self.quant_mode not in [None, 'extrinsic', 'hybrid', 'intrinsic']:
            raise ValueError(
                "Invalid quantization mode. The 'quant_mode' argument must be "
                "one of 'extrinsic', 'intrinsic', 'hybrid' or None.")

        if isinstance(self.quantizer, list) and len(self.quantizer) == 3:
            quantizer_input = self.quantizer[0]
            quantizer_weight = self.quantizer[1]
            quantizer_output = self.quantizer[2]
        else:
            quantizer_input = self.quantizer
            quantizer_weight = self.quantizer
            quantizer_output = self.quantizer

        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalize_inference():
            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(self.moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(self.moving_variance,
                                                      broadcast_shape)
                if self.center:
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                else:
                    broadcast_beta = None
                if self.scale:
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                else:
                    broadcast_gamma = None

                if self.quant_mode in ['hybrid', 'intrinsic']:
                    broadcast_moving_mean = quantizer_weight.quantize(
                        broadcast_moving_mean)
                    broadcast_moving_variance = quantizer_weight.quantize(
                        broadcast_moving_variance)
                    if self.center:
                        broadcast_beta = quantizer_weight.quantize(
                            broadcast_beta)
                    if self.scale:
                        broadcast_gamma = quantizer_weight.quantize(
                            broadcast_gamma)

                if self.quant_mode in ['hybrid', 'intrinsic']:
                    quantized_inputs = quantizer_input.quantize(inputs)

                if self.quant_mode == 'intrinsic':
                    return QuantizedBatchNormalizationCore(
                        quantized_inputs, broadcast_moving_mean,
                        broadcast_moving_variance, broadcast_beta,
                        broadcast_gamma, self.epsilon, quantizer_output)
                elif self.quant_mode == 'hybrid':
                    output = K.batch_normalization(quantized_inputs,
                                                   broadcast_moving_mean,
                                                   broadcast_moving_variance,
                                                   broadcast_beta,
                                                   broadcast_gamma,
                                                   axis=self.axis,
                                                   epsilon=self.epsilon)
                    return quantizer_output.quantize(output)
                elif self.quant_mode == 'extrinsic':
                    output = K.batch_normalization(inputs,
                                                   broadcast_moving_mean,
                                                   broadcast_moving_variance,
                                                   broadcast_beta,
                                                   broadcast_gamma,
                                                   axis=self.axis,
                                                   epsilon=self.epsilon)
                    return quantizer_output.quantize(output)
                elif self.quant_mode is None:
                    return K.batch_normalization(inputs,
                                                 broadcast_moving_mean,
                                                 broadcast_moving_variance,
                                                 broadcast_beta,
                                                 broadcast_gamma,
                                                 axis=self.axis,
                                                 epsilon=self.epsilon)

            else:
                if self.quant_mode in ['hybrid', 'intrinsic']:
                    moving_mean = quantizer_weight.quantize(self.moving_mean)
                    moving_variance = quantizer_weight.quantize(
                        self.moving_variance)
                    if self.center:
                        beta = quantizer_weight.quantize(self.beta)
                    else:
                        beta = self.beta
                    if self.scale:
                        gamma = quantizer_weight.quantize(self.gamma)
                    else:
                        gamma = self.gamma

                if self.quant_mode in ['hybrid', 'intrinsic']:
                    quantized_inputs = quantizer_input.quantize(inputs)

                if self.quant_mode == 'intrinsic':
                    return QuantizedBatchNormalizationCore(
                        quantized_inputs, moving_mean, moving_variance, beta,
                        gamma, self.epsilon, quantizer_output)
                elif self.quant_mode == 'hybrid':
                    output = K.batch_normalization(quantized_inputs,
                                                   moving_mean,
                                                   moving_variance,
                                                   beta,
                                                   gamma,
                                                   axis=self.axis,
                                                   epsilon=self.epsilon)
                    return quantizer_output.quantize(output)
                elif self.quant_mode == 'extrinsic':
                    output = K.batch_normalization(inputs,
                                                   self.moving_mean,
                                                   self.moving_variance,
                                                   self.beta,
                                                   self.gamma,
                                                   axis=self.axis,
                                                   epsilon=self.epsilon)
                    return quantizer_output.quantize(output)
                elif self.quant_mode is None:
                    return K.batch_normalization(inputs,
                                                 self.moving_mean,
                                                 self.moving_variance,
                                                 self.beta,
                                                 self.gamma,
                                                 axis=self.axis,
                                                 epsilon=self.epsilon)

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()

        # If the learning is either dynamic, or set to training:
        normed_training, mean, variance = K.normalize_batch_in_training(
            inputs,
            self.gamma,
            self.beta,
            reduction_axes,
            epsilon=self.epsilon)

        if K.backend() != 'cntk':
            sample_size = K.prod(
                [K.shape(inputs)[axis] for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        self.add_update([
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            K.moving_average_update(self.moving_variance, variance,
                                    self.momentum)
        ], inputs)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)
Example No. 13
    def call(self, inputs, training=None):
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(inputs)

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        mean_batch, var_batch = K.moments(inputs,
                                          reduction_axes,
                                          shift=None,
                                          keep_dims=False)
        std_batch = (K.sqrt(var_batch + self.epsilon))

        r_max_value = K.get_value(self.r_max)
        r = std_batch / (K.sqrt(self.running_variance + self.epsilon))
        r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

        d_max_value = K.get_value(self.d_max)
        d = (mean_batch - self.running_mean) / K.sqrt(self.running_variance +
                                                      self.epsilon)
        d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

        if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
            x_normed_batch = (inputs - mean_batch) / std_batch
            x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
        else:
            # need broadcasting
            broadcast_mean = K.reshape(mean_batch, broadcast_shape)
            broadcast_std = K.reshape(std_batch, broadcast_shape)
            broadcast_r = K.reshape(r, broadcast_shape)
            broadcast_d = K.reshape(d, broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

            x_normed_batch = (inputs - broadcast_mean) / broadcast_std
            x_normed = (x_normed_batch * broadcast_r +
                        broadcast_d) * broadcast_gamma + broadcast_beta

        # explicit update to moving mean and standard deviation
        self.add_update([
            K.moving_average_update(self.running_mean, mean_batch,
                                    self.momentum),
            K.moving_average_update(self.running_variance, std_batch**2,
                                    self.momentum)
        ], inputs)

        # update r_max and d_max
        t_val = K.get_value(self.t)
        r_val = self.r_max_value / (1 +
                                    (self.r_max_value - 1) * np.exp(-t_val))
        d_val = self.d_max_value / (1 + (
            (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
        t_val += float(self.t_delta)

        self.add_update([
            K.update(self.r_max, r_val),
            K.update(self.d_max, d_val),
            K.update(self.t, t_val)
        ], inputs)

        if training in {0, False}:
            return x_normed
        else:

            def normalize_inference():
                if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
                    x_normed_running = K.batch_normalization(
                        inputs,
                        self.running_mean,
                        self.running_variance,
                        self.beta,
                        self.gamma,
                        epsilon=self.epsilon)

                    return x_normed_running
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean,
                                                       broadcast_shape)
                    broadcast_running_std = K.reshape(self.running_variance,
                                                      broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        inputs,
                        broadcast_running_mean,
                        broadcast_running_std,
                        broadcast_beta,
                        broadcast_gamma,
                        epsilon=self.epsilon)

                    return x_normed_running

            # pick the normalized form of inputs corresponding to the training phase
            # for batch renormalization, inference time remains same as batchnorm
            x_normed = K.in_train_phase(x_normed,
                                        normalize_inference,
                                        training=training)

            return x_normed
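
The r_max and d_max schedules above ramp smoothly as t grows: r_val starts at 1 and approaches r_max_value, while d_val starts near 1e-3 and approaches d_max_value. A quick look at the r ramp in NumPy (r_max_value = 3 is just an illustrative choice):

    import numpy as np

    r_max_value = 3.0
    for t in [0.0, 1.0, 5.0, 20.0]:
        r_val = r_max_value / (1 + (r_max_value - 1) * np.exp(-t))
        print(t, round(float(r_val), 3))   # 1.0 at t=0, approaching 3.0 for large t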
Example No. 14
    def call(self, inputs, training = None):

        input_shape = K.int_shape(inputs) # .shape
        ndim        = len(input_shape) # 4

        reduction_axes = list(range(ndim)) # If ndim == 4, list(range(ndim)) == [0, 1, 2, 3]
        del reduction_axes[self.axis] # --> [0, 1, 2], self.axis == -1

        input_dim = input_shape[self.axis] // 2

        mu = K.mean(inputs, axis = reduction_axes) # real mu, imag mu

        broadcast_mu_shape            = [1] * len(input_shape) # [1, 1, 1, 1]
        broadcast_mu_shape[self.axis] = input_shape[self.axis] # [1, 1, 1, input_shape[self.axis]]
        broadcast_mu                  = K.reshape(mu, broadcast_mu_shape) # mu shape is [1, 1, 1, 2]

        """
        real parts에는 real mean을 빼고
        imag parts에는 imag mean을 뺀다
        centred_squared == (x - E(x))^2
        """
        if self.center:
            input_centred = inputs - broadcast_mu
        else:
            input_centred = inputs

        centred_squared = input_centred ** 2

        # for Conv2D: split the real and imaginary halves along the last axis
        centred_squared_real = centred_squared[:, :, :, :input_dim] # real
        centred_squared_imag = centred_squared[:, :, :, input_dim:] # imag
        centred_real = input_centred[:, :, :, :input_dim] # real
        centred_imag = input_centred[:, :, :, input_dim:] # imag

        if self.scale:
            Vrr = K.mean(centred_squared_real, axis=reduction_axes) + self.epsilon
            Vii = K.mean(centred_squared_imag, axis=reduction_axes) + self.epsilon
            Vri = K.mean(centred_real * centred_imag, axis=reduction_axes,) # Vri contains the real and imaginary covariance for each feature map.
        elif self.center:
            Vrr = None
            Vii = None
            Vri = None
        else:
            raise ValueError('Error. Both scale and center in batchnorm are set to False.')

        """
        1. Calcultae BatchNormalization for real parts, imag parts of complex numbers
        2. If Training == True, Under self.center and self.scale condition, Update parameter moving mean, moving_Vrr, moving_Vii, moving_Vri
        """
        input_bn = complex_batchnorm(input_centred, Vrr, Vii, Vri, self.beta, self.gamma_rr, self.gamma_ri, self.gamma_ii, self.scale, self.center, axis = self.axis)

        if training in {0, False}:
            return input_bn
        else:  # training is True
            update_list = []
            if self.center:
                update_list.append(K.moving_average_update(self.moving_mean, mu, self.momentum))
            if self.scale:
                update_list.append(K.moving_average_update(self.moving_Vrr, Vrr, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vii, Vii, self.momentum))
                update_list.append(K.moving_average_update(self.moving_Vri, Vri, self.momentum))
            self.add_update(update_list, inputs)

            def normalize_inference():
                if self.center:
                    inference_centred = inputs - K.reshape(self.moving_mean, broadcast_mu_shape)
                else:
                    inference_centred = inputs
                return complex_batchnorm(inference_centred, 
                                self.moving_Vrr, self.moving_Vii, self.moving_Vri, self.beta, 
                                self.gamma_rr, self.gamma_ri, self.gamma_ii, self.scale, self.center, axis = self.axis)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(input_bn, normalize_inference, training = training)
Example No. 15
 def _moving_average(self, var, value, momentum):
     if self._tf1:
         return self._assign(var, var * momentum + value * (1 - momentum))
     result = K.moving_average_update(var, value, momentum)
     self._updates.append(result)
     return result
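
In both branches the variable ends up following the usual EMA rule var <- momentum * var + (1 - momentum) * value. A tiny eager-mode check, assuming TF 2.x behaviour of K.moving_average_update (no zero-debiasing):

    from tensorflow.keras import backend as K

    var = K.variable(0.0)
    K.moving_average_update(var, 10.0, 0.9)
    print(K.get_value(var))  # expected: 0.9 * 0.0 + 0.1 * 10.0 = 1.0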
Example No. 16
 def average_op(itself, var, average_var):
     decay = tf.constant(self.hull.decay_fn(self.hull.step), dtype=tf.float32)
     return backend.moving_average_update(average_var, var, decay)
Example No. 17
    def call(self, x, mask=None):
        if self.mode == 0 or self.mode == 2:
            assert self.built, 'Layer must be built before being called'
            input_shape = K.int_shape(x)

            reduction_axes = list(range(len(input_shape)))
            del reduction_axes[self.axis]
            broadcast_shape = [1] * len(input_shape)
            broadcast_shape[self.axis] = input_shape[self.axis]

            # mean_batch, var_batch = K.moments(x, reduction_axes, shift=None, keep_dims=False)
            normed, mean_batch, var_batch = K.normalize_batch_in_training(
                x, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)

            std_batch = (K.sqrt(var_batch + self.epsilon))

            r_max_value = K.get_value(self.r_max)
            r = std_batch / (K.sqrt(self.running_std + self.epsilon))
            r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

            d_max_value = K.get_value(self.d_max)
            d = (mean_batch - self.running_mean) / K.sqrt(self.running_std +
                                                          self.epsilon)
            d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

            if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                x_normed_batch = (x - mean_batch) / std_batch
                x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
            else:
                # need broadcasting
                broadcast_mean = K.reshape(mean_batch, broadcast_shape)
                broadcast_std = K.reshape(std_batch, broadcast_shape)
                broadcast_r = K.reshape(r, broadcast_shape)
                broadcast_d = K.reshape(d, broadcast_shape)
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

                x_normed_batch = (x - broadcast_mean) / broadcast_std
                x_normed = (x_normed_batch * broadcast_r +
                            broadcast_d) * broadcast_gamma + broadcast_beta

            # explicit update to moving mean and standard deviation
            self.add_update([
                K.moving_average_update(self.running_mean, mean_batch,
                                        self.momentum),
                K.moving_average_update(self.running_std, std_batch**2,
                                        self.momentum)
            ], x)

            # update r_max and d_max
            t_val = K.get_value(self.t)
            r_val = self.r_max_value / (
                1 + (self.r_max_value - 1) * np.exp(-t_val))
            d_val = self.d_max_value / (1 + (
                (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
            t_val += float(self.t_delta)

            self.add_update([
                K.update(self.r_max, r_val),
                K.update(self.d_max, d_val),
                K.update(self.t, t_val)
            ], x)

            if self.mode == 0:
                if sorted(reduction_axes) == list(range(K.ndim(x)))[:-1]:
                    x_normed_running = K.batch_normalization(
                        x,
                        self.running_mean,
                        self.running_std,
                        self.beta,
                        self.gamma,
                        epsilon=self.epsilon)
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean,
                                                       broadcast_shape)
                    broadcast_running_std = K.reshape(self.running_std,
                                                      broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        x,
                        broadcast_running_mean,
                        broadcast_running_std,
                        broadcast_beta,
                        broadcast_gamma,
                        epsilon=self.epsilon)

                # pick the normalized form of x corresponding to the training phase
                # for batch renormalization, inference time remains same as batchnorm
                x_normed = K.in_train_phase(x_normed, x_normed_running)

        elif self.mode == 1:
            # sample-wise normalization
            m = K.mean(x, axis=self.axis, keepdims=True)
            std = K.sqrt(
                K.var(x, axis=self.axis, keepdims=True) + self.epsilon)
            x_normed_batch = (x - m) / (std + self.epsilon)

            r_max_value = K.get_value(self.r_max)
            r = std / (self.running_std + self.epsilon)
            r = K.stop_gradient(K.clip(r, 1 / r_max_value, r_max_value))

            d_max_value = K.get_value(self.d_max)
            d = (m - self.running_mean) / (self.running_std + self.epsilon)
            d = K.stop_gradient(K.clip(d, -d_max_value, d_max_value))

            x_normed = ((x_normed_batch * r) + d) * self.gamma + self.beta

            # update r_max and d_max
            t_val = K.get_value(self.t)
            r_val = self.r_max_value / (
                1 + (self.r_max_value - 1) * np.exp(-t_val))
            d_val = self.d_max_value / (1 + (
                (self.d_max_value / 1e-3) - 1) * np.exp(-(2 * t_val)))
            t_val += float(self.t_delta)

            self.add_update([
                K.update(self.r_max, r_val),
                K.update(self.d_max, d_val),
                K.update(self.t, t_val)
            ], x)

        return x_normed
Example No. 18
    def call(self, inputs, training=None):
        # These were moved here from build() because tf2 eager was not
        # tracking gradients:
        repeated_gamma = K.reshape(
            K.tile(K.expand_dims(self.gamma, -1), [1, self.n]),
            [-1],
        )
        repeated_beta = K.reshape(
            K.tile(K.expand_dims(self.beta, -1), [1, self.n]),
            [-1],
        )

        repeated_moving_mean = K.reshape(
            K.tile(K.expand_dims(self.moving_mean, -1), [1, self.n]),
            [-1],
        )
        repeated_moving_variance = K.reshape(
            K.tile(K.expand_dims(self.moving_variance, -1), [1, self.n]),
            [-1],
        )

        def unrepeat(w):
            n = 1
            if self.h == 'C4':
                n *= 4
            elif self.h == 'D4':
                n *= 8
            elif self.h == 'Z2':
                n *= 1
            else:
                raise ValueError('Wrong h: %s' % self.h)

            return K.mean(K.reshape(w, (K.int_shape(w)[0] // n, n)), -1)

        input_shape = K.int_shape(inputs)
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalize_inference():
            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(repeated_moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(repeated_moving_variance,
                                                      broadcast_shape)

                broadcast_beta = K.reshape(repeated_beta, broadcast_shape)

                broadcast_gamma = K.reshape(repeated_gamma, broadcast_shape)

                return K.batch_normalization(inputs,
                                             broadcast_moving_mean,
                                             broadcast_moving_variance,
                                             broadcast_beta,
                                             broadcast_gamma,
                                             epsilon=self.epsilon)
            else:
                return K.batch_normalization(inputs,
                                             repeated_moving_mean,
                                             repeated_moving_variance,
                                             repeated_beta,
                                             repeated_gamma,
                                             epsilon=self.epsilon)

        def _get_training_value(training, trainable_flag):
            """
            Return a flag indicating whether a layer should be called in training
            or inference mode.
            Modified from https://git.io/JUGHX
            training: the training setting passed when the layer is called.
            trainable_flag: flag indicating whether the layer is trainable.
            """
            if training is None:
                training = K.learning_phase()

            if isinstance(training, int):
                training = bool(training)

            # If layer not trainable, override value passed from model.
            if trainable_flag is False:
                training = False

            return training

        # If the learning phase is *static* and set to inference:
        training_val = _get_training_value(training, self.trainable)
        if training_val is False:
            return normalize_inference()

        # If the learning is either dynamic, or set to training:
        normed_training, mean, variance = K.normalize_batch_in_training(
            inputs,
            repeated_gamma,
            repeated_beta,
            reduction_axes,
            epsilon=self.epsilon)

        if K.backend() != 'cntk':
            sample_size = K.prod(
                [K.shape(inputs)[axis] for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs))

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        self.add_update([
            K.moving_average_update(self.moving_mean, unrepeat(mean),
                                    self.momentum),
            K.moving_average_update(self.moving_variance, unrepeat(variance),
                                    self.momentum)
        ], inputs)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)
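
In the example above, each per-channel parameter is repeated n times (n being the order of the symmetry group h) before normalization, and unrepeat averages the statistics back to one value per channel. A small round-trip check of that repeat/unrepeat pair (n = 4, as for h = 'C4'):

    from tensorflow.keras import backend as K

    gamma = K.constant([1.0, 2.0, 3.0])
    n = 4
    repeated = K.reshape(K.tile(K.expand_dims(gamma, -1), [1, n]), [-1])
    recovered = K.mean(K.reshape(repeated, (K.int_shape(repeated)[0] // n, n)), -1)
    print(K.get_value(recovered))  # [1. 2. 3.]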
Example No. 19
    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        ndim = len(input_shape)
        reduction_axes = list(range(ndim))
        del reduction_axes[self.axis]
        input_dim = input_shape[self.axis] // 2
        mu = K.mean(inputs, axis=reduction_axes)
        broadcast_mu_shape = [1] * len(input_shape)
        broadcast_mu_shape[self.axis] = input_shape[self.axis]
        broadcast_mu = K.reshape(mu, broadcast_mu_shape)
        if self.center:
            input_centred = inputs - broadcast_mu
        else:
            input_centred = inputs
        centred_squared = input_centred**2
        if (self.axis == 1 and ndim != 3) or ndim == 2:
            centred_squared_real = centred_squared[:, :input_dim]
            centred_squared_imag = centred_squared[:, input_dim:]
            centred_real = input_centred[:, :input_dim]
            centred_imag = input_centred[:, input_dim:]
        elif ndim == 3:
            centred_squared_real = centred_squared[:, :, :input_dim]
            centred_squared_imag = centred_squared[:, :, input_dim:]
            centred_real = input_centred[:, :, :input_dim]
            centred_imag = input_centred[:, :, input_dim:]
        elif self.axis == -1 and ndim == 4:
            centred_squared_real = centred_squared[:, :, :, :input_dim]
            centred_squared_imag = centred_squared[:, :, :, input_dim:]
            centred_real = input_centred[:, :, :, :input_dim]
            centred_imag = input_centred[:, :, :, input_dim:]
        elif self.axis == -1 and ndim == 5:
            centred_squared_real = centred_squared[:, :, :, :, :input_dim]
            centred_squared_imag = centred_squared[:, :, :, :, input_dim:]
            centred_real = input_centred[:, :, :, :, :input_dim]
            centred_imag = input_centred[:, :, :, :, input_dim:]
        else:
            raise ValueError(
                'Incorrect Batchnorm combination of axis and dimensions. axis should be either 1 or -1. '
                'axis: ' + str(self.axis) + '; ndim: ' + str(ndim) + '.')
        if self.scale:
            Vrr = K.mean(centred_squared_real,
                         axis=reduction_axes) + self.epsilon
            Vii = K.mean(centred_squared_imag,
                         axis=reduction_axes) + self.epsilon
            # Vri contains the real and imaginary covariance for each feature map.
            Vri = K.mean(
                centred_real * centred_imag,
                axis=reduction_axes,
            )
        elif self.center:
            Vrr = None
            Vii = None
            Vri = None
        else:
            raise ValueError(
                'Error. Both scale and center in batchnorm are set to False.')

        input_bn = ComplexBN(input_centred,
                             Vrr,
                             Vii,
                             Vri,
                             self.beta,
                             self.gamma_rr,
                             self.gamma_ri,
                             self.gamma_ii,
                             self.scale,
                             self.center,
                             axis=self.axis)
        if training in {0, False}:
            return input_bn
        else:
            update_list = []
            if self.center:
                update_list.append(
                    K.moving_average_update(self.moving_mean, mu,
                                            self.momentum))
            if self.scale:
                update_list.append(
                    K.moving_average_update(self.moving_Vrr, Vrr,
                                            self.momentum))
                update_list.append(
                    K.moving_average_update(self.moving_Vii, Vii,
                                            self.momentum))
                update_list.append(
                    K.moving_average_update(self.moving_Vri, Vri,
                                            self.momentum))
            self.add_update(update_list)

            def normalize_inference():
                if self.center:
                    inference_centred = inputs - K.reshape(
                        self.moving_mean, broadcast_mu_shape)
                else:
                    inference_centred = inputs
                return ComplexBN(inference_centred,
                                 self.moving_Vrr,
                                 self.moving_Vii,
                                 self.moving_Vri,
                                 self.beta,
                                 self.gamma_rr,
                                 self.gamma_ri,
                                 self.gamma_ii,
                                 self.scale,
                                 self.center,
                                 axis=self.axis)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(input_bn,
                                normalize_inference,
                                training=training)
Example No. 20
 def _get_update_list(self, kernel):
     self.moving_heatmap.assign(self.heatmap_momentum * self.moving_heatmap + (1. - self.heatmap_momentum) * K.sign(kernel))
     update_list = [
         K.moving_average_update(self.moving_heatmap, K.sign(kernel), self.heatmap_momentum),
     ]
     return update_list