def call(self, inputs, **kwargs):
        input_shape = K.int_shape(inputs)
        tensor_input_shape = K.shape(inputs)

        # Prepare broadcasting shape.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
        broadcast_shape.insert(1, self.groups)

        reshape_group_shape = K.shape(inputs)
        group_axes = [reshape_group_shape[i] for i in range(len(input_shape))]
        group_axes[self.axis] = input_shape[self.axis] // self.groups
        group_axes.insert(1, self.groups)

        # reshape inputs to new group shape
        group_shape = [group_axes[0], self.groups] + group_axes[2:]
        group_shape = K.stack(group_shape)
        inputs = K.reshape(inputs, group_shape)

        group_reduction_axes = list(range(len(group_axes)))
        mean, variance = KC.moments(inputs,
                                    group_reduction_axes[2:],
                                    keep_dims=True)
        inputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))

        # the tensor still has the grouped shape here; keep it so that gamma
        # and beta can be broadcast per group below
        inputs = K.reshape(inputs, group_shape)

        outputs = inputs

        # In this case we must explicitly broadcast all parameters.
        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            outputs = outputs * broadcast_gamma

        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            outputs = outputs + broadcast_beta

        # finally we reshape the output back to the input shape
        outputs = K.reshape(outputs, tensor_input_shape)

        return outputs
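A minimal NumPy mirror of the reshape-and-normalize steps in this call, assuming a 4D channels-last input with axis=-1 (the function name grouped_normalize and the shapes are illustrative, not part of the layer):

import numpy as np

def grouped_normalize(x, groups, epsilon=1e-5):
    # Mirror of the call() above for an (N, H, W, C) input with axis=-1:
    # reshape to the group shape [N, G, H, W, C // G], normalize over the
    # axes after the group axis, then reshape back to the input shape.
    n, h, w, c = x.shape
    grouped = x.reshape(n, groups, h, w, c // groups)
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    normed = (grouped - mean) / np.sqrt(var + epsilon)
    return normed.reshape(n, h, w, c)

x = np.random.random((2, 4, 4, 8)).astype('float32')
print(grouped_normalize(x, groups=2).shape)  # (2, 4, 4, 8)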
Example #2
    def test_moments(self, keep_dims):
        input_shape = (10, 10, 10, 10)
        x_0 = np.zeros(input_shape)
        x_1 = np.ones(input_shape)
        x_random = np.random.random(input_shape)

        th_axes = [0, 2, 3]
        tf_axes = [0, 1, 2]

        for ip in [x_0, x_1, x_random]:
            for axes in [th_axes, tf_axes]:
                K_mean, K_var = KC.moments(K.variable(ip), axes, keep_dims=keep_dims)
                np_mean, np_var = KCNP.moments(ip, axes, keep_dims=keep_dims)

                K_mean_val = K.eval(K_mean)
                K_var_val = K.eval(K_var)

                # absolute tolerance needed when working with zeros
                assert_allclose(K_mean_val, np_mean, rtol=1e-4, atol=1e-10)
                assert_allclose(K_var_val, np_var, rtol=1e-4, atol=1e-10)
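For reference, the NumPy side of this comparison reduces to the mean and the biased (population) variance over the given axes. A stand-in sketch in case KCNP.moments is not available in your checkout (np_moments is an illustrative name, not the keras_contrib API):

import numpy as np

def np_moments(x, axes, keep_dims=False):
    # Mean and biased variance over `axes`, honouring the keep_dims flag.
    mean = np.mean(x, axis=tuple(axes), keepdims=keep_dims)
    var = np.var(x, axis=tuple(axes), keepdims=keep_dims)
    return mean, var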
    def call(self, inputs, training=None):
        assert self.built, 'Layer must be built before being called'
        input_shape = K.int_shape(inputs)

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        mean_batch, var_batch = KC.moments(inputs, reduction_axes,
                                           shift=None, keep_dims=False)
        std_batch = (K.sqrt(var_batch + self.epsilon))

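        # Batch Renormalization correction terms: r rescales the batch std
        # towards the running std, d shifts the batch mean towards the running
        # mean; both are clipped and wrapped in stop_gradient so they act as
        # constants during backpropagation.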
        r = std_batch / (K.sqrt(self.running_variance + self.epsilon))
        r = K.stop_gradient(K.clip(r, 1 / self.r_max, self.r_max))

        d = (mean_batch - self.running_mean) / K.sqrt(self.running_variance
                                                      + self.epsilon)
        d = K.stop_gradient(K.clip(d, -self.d_max, self.d_max))

        if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
            x_normed_batch = (inputs - mean_batch) / std_batch
            x_normed = (x_normed_batch * r + d) * self.gamma + self.beta
        else:
            # need broadcasting
            broadcast_mean = K.reshape(mean_batch, broadcast_shape)
            broadcast_std = K.reshape(std_batch, broadcast_shape)
            broadcast_r = K.reshape(r, broadcast_shape)
            broadcast_d = K.reshape(d, broadcast_shape)
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)

            x_normed_batch = (inputs - broadcast_mean) / broadcast_std
            x_normed = (x_normed_batch * broadcast_r
                        + broadcast_d) * broadcast_gamma + broadcast_beta

        # explicit update to moving mean and variance
        mean_update = K.moving_average_update(self.running_mean,
                                              mean_batch,
                                              self.momentum)
        variance_update = K.moving_average_update(self.running_variance,
                                                  std_batch ** 2,
                                                  self.momentum)
        self.add_update([mean_update, variance_update], inputs)

        # update r_max and d_max
        r_val = self.r_max_value / (1 + (self.r_max_value - 1) * K.exp(-self.t))
        d_val = (self.d_max_value
                 / (1 + ((self.d_max_value / 1e-3) - 1) * K.exp(-(2 * self.t))))

        self.add_update([K.update(self.r_max, r_val),
                         K.update(self.d_max, d_val),
                         K.update_add(self.t, self.t_delta_tensor)], inputs)

        if training in {0, False}:
            return x_normed
        else:
            def normalize_inference():
                if sorted(reduction_axes) == list(range(K.ndim(inputs)))[:-1]:
                    x_normed_running = K.batch_normalization(
                        inputs, self.running_mean, self.running_variance,
                        self.beta, self.gamma,
                        epsilon=self.epsilon)

                    return x_normed_running
                else:
                    # need broadcasting
                    broadcast_running_mean = K.reshape(self.running_mean,
                                                       broadcast_shape)
                    broadcast_running_variance = K.reshape(self.running_variance,
                                                           broadcast_shape)
                    broadcast_beta = K.reshape(self.beta, broadcast_shape)
                    broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                    x_normed_running = K.batch_normalization(
                        inputs, broadcast_running_mean, broadcast_running_variance,
                        broadcast_beta, broadcast_gamma,
                        epsilon=self.epsilon)

                    return x_normed_running

            # pick the normalized form of inputs corresponding to the training phase
            # for batch renormalization, inference time remains same as batchnorm
            x_normed = K.in_train_phase(x_normed, normalize_inference,
                                        training=training)

            return x_normed
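The r/d terms above implement the Batch Renormalization correction: activations are normalized with batch statistics during training and then corrected so that, up to the clipping by r_max and d_max, the result matches normalization with the running statistics. A small NumPy check of that identity (all names and values here are illustrative):

import numpy as np

x = np.random.random(1000)
mu_b, sigma_b = x.mean(), x.std()
mu_run, sigma_run = 0.3, 0.5            # stand-ins for the running moments

r = sigma_b / sigma_run                 # unclipped correction factors
d = (mu_b - mu_run) / sigma_run

lhs = (x - mu_b) / sigma_b * r + d      # batch-normalized, then corrected
rhs = (x - mu_run) / sigma_run          # normalized with running statistics
print(np.allclose(lhs, rhs))            # True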
Example #4
    def call(self, inputs, training=None):
        old_shape = (tf.shape(inputs)[0], ) + K.int_shape(inputs)[1:]
        if self.group:
            new_shape = (tf.shape(inputs)[0], ) + old_shape[1:3] + (
                self.group, old_shape[self.axis] // self.group)
            inputs_reshape = K.reshape(inputs, shape=new_shape)
            inputs_reshape = K.permute_dimensions(inputs_reshape,
                                                  pattern=(0, 1, 2, 4, 3))
            input_shape = K.int_shape(inputs_reshape)
        else:
            inputs_reshape = inputs
            input_shape = old_shape
        # Prepare broadcasting shape.
        ndim = len(input_shape)
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        # Determines whether broadcasting is needed.
        needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

        def normalization(mean, variance):
            normed_training = (inputs_reshape - mean) / (K.sqrt(variance +
                                                                self.epsilon))

            # In this case we must explicitly broadcast all parameters.
            if self.scale:
                broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
                normed_training = normed_training * broadcast_gamma

            if self.center:
                broadcast_beta = K.reshape(self.beta, broadcast_shape)
                normed_training = normed_training + broadcast_beta

            return normed_training

        def normalize_inference():
            if needs_broadcasting:
                # In this case we must explicitly broadcast all parameters.
                broadcast_moving_mean = K.reshape(self.moving_mean,
                                                  broadcast_shape)
                broadcast_moving_variance = K.reshape(self.moving_variance,
                                                      broadcast_shape)
                outputs = normalization(broadcast_moving_mean,
                                        broadcast_moving_variance)
                return K.reshape(
                    K.permute_dimensions(outputs, (0, 1, 2, 4, 3)), old_shape)
            else:
                outputs = normalization(self.moving_mean, self.moving_variance)
                return K.reshape(
                    K.permute_dimensions(outputs, (0, 1, 2, 4, 3)), old_shape)

        # If the learning phase is *static* and set to inference:
        if training in {0, False}:
            return normalize_inference()

        # If the learning is either dynamic, or set to training:
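        # When grouping is enabled, inputs_reshape has shape
        # (batch, H, W, C // group, group), so reducing over axes [1, 2, 3]
        # gives one mean/variance per sample and per group.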
        mean, variance = KC.moments(inputs_reshape, [1, 2, 3], keep_dims=True)
        normed_training = (inputs_reshape - mean) / (K.sqrt(variance +
                                                            self.epsilon))

        # In this case we must explicitly broadcast all parameters.
        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed_training = normed_training * broadcast_gamma

        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed_training = normed_training + broadcast_beta

        mean = K.mean(mean, [0, 1, 2, 3])
        variance = K.mean(variance, [0, 1, 2, 3])

        if K.backend() != 'cntk':
            sample_size = K.prod(
                [K.shape(inputs_reshape)[axis] for axis in reduction_axes])
            sample_size = K.cast(sample_size, dtype=K.dtype(inputs_reshape))
            if K.backend() == 'tensorflow' and sample_size.dtype != 'float32':
                sample_size = K.cast(sample_size, dtype='float32')

            # sample variance - unbiased estimator of population variance
            variance *= sample_size / (sample_size - (1.0 + self.epsilon))

        self.add_update([
            K.moving_average_update(self.moving_mean, mean, self.momentum),
            K.moving_average_update(self.moving_variance, variance,
                                    self.momentum)
        ], inputs)

        normed_training = K.reshape(
            K.permute_dimensions(normed_training, (0, 1, 2, 4, 3)), old_shape)

        # Pick the normalized form corresponding to the training phase.
        return K.in_train_phase(normed_training,
                                normalize_inference,
                                training=training)
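For reference, the reshape/permute round trip used in the grouped branch, sketched in NumPy for a channels-last 4D input with axis=-1 (values are illustrative):

import numpy as np

x = np.arange(2 * 4 * 4 * 8, dtype='float32').reshape(2, 4, 4, 8)
g = 2
# (N, H, W, C) -> (N, H, W, G, C // G) -> permute -> (N, H, W, C // G, G)
grouped = x.reshape(2, 4, 4, g, 8 // g).transpose(0, 1, 2, 4, 3)
# statistics are reduced over H, W and C // G (axes 1, 2, 3) in the layer;
# the inverse permute plus reshape restores the original layout
restored = grouped.transpose(0, 1, 2, 4, 3).reshape(2, 4, 4, 8)
print(np.array_equal(x, restored))  # True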