Example #1
    def call(self, inputs, training=None):

        lam = inputs
        K = int(inputs.shape[-1] // 2)
        U = 2

        if training:
            layer_loss = 0.

            # reshape weight for LWTA
            lam_re = tf.reshape(lam, [-1, K, U])

            # calculate probability of activation and some stability operations
            prbs = tf.nn.softmax(lam_re) + 1e-4
            prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

            # relaxed categorical sample
            xi = concrete_sample(prbs, 0.67)

            #apply activation
            out = lam_re * xi
            out = tf.reshape(out, tf.shape(input=lam))

            # kl for the relaxed categorical variables
            kl_xi = tf.reduce_mean(input_tensor=tf.reduce_sum(
                input_tensor=concrete_kl(tf.ones([1, K, U]) / U, prbs, xi),
                axis=[1]))
            # print(kl_xi) #negative #something very small
            tf.compat.v1.add_to_collection('kl_loss', kl_xi)
            # self.add_loss(tf.math.reduce_mean(kl_xi)/60000)
            layer_loss = layer_loss + tf.math.reduce_mean(kl_xi) / 100000
            tf.compat.v2.summary.scalar(name='kl_xi', data=kl_xi)

        else:

            layer_loss = 0.

            lam_re = tf.reshape(lam, [-1, K, U])
            prbs = tf.nn.softmax(lam_re) + 1e-4
            prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

            # apply activation
            out = lam_re * concrete_sample(prbs, 0.01)
            out = tf.reshape(out, tf.shape(input=lam))

        self.add_loss(layer_loss)

        return out, prbs
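The example above (and the ones that follow) call two helpers that are not shown here: concrete_sample and concrete_kl. The sketch below is only a plausible reconstruction, not the repository's actual code: concrete_sample is assumed to draw a Gumbel-Softmax (Concrete) relaxed one-hot sample over the last axis from logits (where normalized probabilities are passed, as in the snippet above, one would take their log first), and concrete_kl is assumed to approximate the relaxed KL by the KL between the underlying categorical distributions, ignoring the drawn sample.

import tensorflow as tf

def concrete_sample(logits, temperature, hard=False, eps=1e-8):
    """Relaxed one-hot sample over the last axis (Gumbel-Softmax trick)."""
    gumbel = -tf.math.log(-tf.math.log(
        tf.random.uniform(tf.shape(logits), minval=eps, maxval=1. - eps)) + eps)
    soft = tf.nn.softmax((logits + gumbel) / temperature, axis=-1)
    if hard:
        # straight-through estimator: hard winner forward, soft gradient backward
        hard_sample = tf.one_hot(tf.argmax(soft, axis=-1),
                                 depth=tf.shape(logits)[-1],
                                 dtype=soft.dtype)
        soft = tf.stop_gradient(hard_sample - soft) + soft
    return soft

def concrete_kl(prior_probs, posterior_probs, sample=None, eps=1e-8):
    """Approximate KL(q || p) over the last axis as a categorical KL;
    the drawn sample is ignored in this approximation."""
    return tf.reduce_sum(
        posterior_probs * (tf.math.log(posterior_probs + eps)
                           - tf.math.log(prior_probs + eps)),
        axis=-1)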
Example #2
def lwta_activation(x, temp, K, U, train=True):
    """
    Implementation of the LWTA activation in a stochastic manner using the Gumbel-Softmax trick.
    The computation is described in the paper "Nonparametric Bayesian Deep Networks with Local Competition".

    @param x: tf.tensor, the input to the activation, i.e., the resulting tensor after the conv operation
    @param temp: float, the temperature of the relaxation of the categorical distribution
    @param K: int, the number of LWTA blocks we consider
    @param U: int, the number of competitors in each block
    @param train: boolean, flag to choose between the train and test branches of the function.

    @return: tf.tensor, LWTA-activated input.
             tf.tensor, the KL divergence for the concrete relaxation.
    """

    kl = 0

    # reshape weight for LWTA
    x_reshaped = tf.reshape(x, [-1, K, U])
    logits = x_reshaped

    xi = concrete_sample(logits, temp, hard=False)

    # apply activation
    out = x_reshaped * xi
    out = tf.reshape(out, tf.shape(input=x))

    if train:
        q = tf.nn.softmax(logits)
        log_q = tf.math.log(q + 1e-8)
        kl = tf.reduce_sum(q*(log_q - tf.math.log(1.0/U)), [1])
        kl = tf.reduce_mean(kl)

    return out, kl
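A minimal usage sketch of lwta_activation, assuming the concrete_sample helper sketched above is in scope; the batch size, block count, and temperature below are illustrative only. Note that this function passes the raw pre-activations to concrete_sample as logits, whereas Example #1 passes normalized probabilities.

# 8 features split into K=4 LWTA blocks of U=2 competitors each (illustrative)
x = tf.random.normal([32, 8])
out, kl = lwta_activation(x, temp=0.67, K=4, U=2, train=True)
print(out.shape)   # (32, 8): each block keeps one (soft) winner, the rest are damped
print(float(kl))   # scalar KL of the block-wise posteriors from the uniform prior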
Example #3
    def call(self, inputs, training=None):

        ksize = int(inputs.shape[-1] // 2)
        lam = inputs

        if training:
            layer_loss = 0.
            # reshape weight to calculate probabilities
            lam_re = tf.reshape(
                lam, [-1, lam.get_shape()[1],
                      lam.get_shape()[2], ksize, 2])

            prbs = tf.nn.softmax(lam_re) + 1e-5
            prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

            # draw relaxed sample and apply activation
            xi = concrete_sample(prbs, 0.5)

            #apply activation
            out = lam_re * xi
            out = tf.reshape(out, tf.shape(input=lam))

            # add the relative kl terms
            kl_xi = tf.reduce_mean(input_tensor=tf.reduce_sum(
                input_tensor=concrete_kl(tf.ones_like(lam_re) / 2, prbs, xi),
                axis=[1]))

            layer_loss = layer_loss + tf.math.reduce_mean(kl_xi) / 100000

        else:

            layer_loss = 0.

            # calculate probabilities of activation
            lam_re = tf.reshape(
                lam, [-1, lam.get_shape()[1],
                      lam.get_shape()[2], ksize, 2])
            prbs = tf.nn.softmax(lam_re) + 1e-5
            prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

            # draw sample for activated units
            out = lam_re * concrete_sample(prbs, 0.01)
            out = tf.reshape(out, tf.shape(input=lam))

        self.add_loss(layer_loss)

        return out, prbs
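In this convolutional variant the competition runs over pairs of adjacent feature maps: a [batch, H, W, 2*ksize] tensor is viewed as [batch, H, W, ksize, 2] and one map of each pair is (softly) selected. At test time the Concrete sample is drawn at temperature 0.01, which behaves almost like a hard winner-take-all mask. For comparison, a hedged sketch of the fully deterministic counterpart (hard argmax instead of a low-temperature sample); the function name is made up for illustration:

def conv_lwta_hard(lam, ksize, U=2):
    """Deterministic channel-wise LWTA for a [batch, H, W, ksize * U] feature map."""
    shape = tf.shape(lam)
    lam_re = tf.reshape(lam, [-1, shape[1], shape[2], ksize, U])
    # one-hot mask of the per-block winners
    mask = tf.one_hot(tf.argmax(lam_re, axis=-1), depth=U, dtype=lam.dtype)
    return tf.reshape(lam_re * mask, shape)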
Example #4
    def call(self, inputs, training=None):

        sW_softplus = tf.nn.softplus(self.sW)

        if training:

            # reparametrizable normal sample
            eps = tf.stop_gradient(
                tf.random.normal([inputs.get_shape()[1], self.K * self.U]))
            # W = self.mW + eps * self.sW
            W = self.mW + eps * sW_softplus

            z = 1.
            layer_loss = 0.

            # sbp
            if self.sbp:

                # posterior concentration variables for the IBP
                conc1_softplus = tf.nn.softplus(self.conc1)
                conc0_softplus = tf.nn.softplus(self.conc0)

                # stick breaking construction
                q_u = kumaraswamy_sample(
                    conc1_softplus,
                    conc0_softplus,
                    sample_shape=[inputs.get_shape()[1], self.K])
                pi = tf.math.cumprod(q_u)

                # posterior probabilities z
                t_pi_sigmoid = tf.nn.sigmoid(self.t_pi)

                # sample relaxed bernoulli
                z_sample = bin_concrete_sample(t_pi_sigmoid, self.temp_bern)
                z = tf.tile(z_sample, [1, self.U])
                re = z * W

                # kl terms for the stick breaking construction
                kl_sticks = tf.reduce_sum(input_tensor=kumaraswamy_kl(
                    tf.ones_like(conc1_softplus), tf.ones_like(conc0_softplus),
                    conc1_softplus, conc0_softplus, q_u))
                kl_z = tf.reduce_sum(input_tensor=bin_concrete_kl(
                    pi, t_pi_sigmoid, self.temp_bern, z_sample))

                tf.compat.v1.add_to_collection(
                    'kl_loss', kl_sticks)  #positive something very big
                tf.compat.v1.add_to_collection(
                    'kl_loss', kl_z)  #negative something very big
                # self.add_loss(tf.math.reduce_mean(kl_sticks)/60000)
                layer_loss = layer_loss + tf.math.reduce_mean(
                    kl_sticks) / 60000
                layer_loss = layer_loss + tf.math.reduce_mean(kl_z) / 60000
                # self.add_loss(tf.math.reduce_mean(kl_z)/60000)

                tf.compat.v2.summary.scalar(name='kl_sticks', data=kl_sticks)
                tf.compat.v2.summary.scalar(name='kl_z', data=kl_z)

                # cut connections if probability of activation less than tau
                tf.compat.v2.summary.scalar(
                    name='sparsity',
                    data=tf.reduce_sum(input_tensor=tf.cast(
                        tf.greater(t_pi_sigmoid /
                                   (1. + t_pi_sigmoid), self.tau), tf.float32))
                    * self.U)
                # sparsity = tf.reduce_sum(input_tensor=tf.cast(tf.greater(t_pi_sigmoid/(1.+t_pi_sigmoid), self.tau), tf.float32))*self.U

            else:
                re = W

            # add the kl for the weights to the collection
            # kl_weights = tf.reduce_sum(input_tensor=normal_kl(tf.zeros_like(self.mW), tf.ones_like(sW_softplus),self.mW, sW_softplus))
            kl_weights = -0.5 * tf.reduce_mean(
                2 * sW_softplus - tf.square(self.mW) - sW_softplus**2 + 1,
                name='kl_weights')

            tf.compat.v1.add_to_collection('kl_loss',
                                           kl_weights)  #something very big
            # self.add_loss(tf.math.reduce_mean(kl_weights)/60000)
            layer_loss = layer_loss + tf.math.reduce_mean(kl_weights) / 60000
            tf.compat.v2.summary.scalar(name='kl_weights', data=kl_weights)

            # dense calculation
            lam = tf.matmul(inputs, re) + self.biases

            if self.activation == 'lwta':
                assert self.U > 1, 'The number of competing units should be larger than 1'

                # reshape weight for LWTA
                lam_re = tf.reshape(lam, [-1, self.K, self.U])

                # calculate probability of activation and some stability operations
                prbs = tf.nn.softmax(lam_re) + 1e-4
                prbs /= tf.reduce_sum(input_tensor=prbs,
                                      axis=-1,
                                      keepdims=True)

                # relaxed categorical sample
                xi = concrete_sample(prbs, self.temp_cat)

                #apply activation
                out = lam_re * xi
                out = tf.reshape(out, tf.shape(input=lam))

                # kl for the relaxed categorical variables
                kl_xi = tf.reduce_mean(
                    input_tensor=tf.reduce_sum(input_tensor=concrete_kl(
                        tf.ones([1, self.K, self.U]) / self.U, prbs, xi),
                                               axis=[1]))
                # print(kl_xi) #negative #something very small
                tf.compat.v1.add_to_collection('kl_loss', kl_xi)
                # self.add_loss(tf.math.reduce_mean(kl_xi)/60000)
                layer_loss = layer_loss + tf.math.reduce_mean(kl_xi) / 60000
                tf.compat.v2.summary.scalar(name='kl_xi', data=kl_xi)

            elif self.activation == 'relu':
                out = tf.nn.relu(lam)
            elif self.activation == 'maxout':
                lam_re = tf.reshape(lam, [-1, self.K, self.U])
                out = tf.reduce_max(input_tensor=lam_re, axis=-1)
            else:
                out = lam

        # test branch of the layer; Keras selects it automatically via the training flag
        else:
            # this differs from the original implementation:
            # we use re for accuracy and z for compression (if sbp is active)
            re = 1.
            z = 1.
            layer_loss = 0.
            # sbp
            if self.sbp:

                # posterior probabilities z
                t_pi_sigmoid = tf.nn.sigmoid(self.t_pi)

                mask = tf.cast(tf.greater(t_pi_sigmoid, self.tau), tf.float32)
                z = tfd.Bernoulli(probs=mask * t_pi_sigmoid,
                                  name="q_z_test",
                                  dtype=tf.float32).sample()
                z = tf.tile(z, [1, self.U])

                re = tf.tile(mask * t_pi_sigmoid, [1, self.U])

            lam = tf.matmul(inputs, re * self.mW) + self.biases

            if self.activation == 'lwta':

                # reshape and calculate winners
                lam_re = tf.reshape(lam, [-1, self.K, self.U])
                prbs = tf.nn.softmax(lam_re) + 1e-4
                prbs /= tf.reduce_sum(input_tensor=prbs,
                                      axis=-1,
                                      keepdims=True)

                # apply activation
                out = lam_re * concrete_sample(prbs, 0.01)
                out = tf.reshape(out, tf.shape(input=lam))

            elif self.activation == 'relu':

                out = tf.nn.relu(lam)

            elif self.activation == 'maxout':

                lam_re = tf.reshape(lam, [-1, self.K, self.U])
                out = tf.reduce_max(input_tensor=lam_re, axis=-1)

            else:
                out = lam
        self.add_loss(layer_loss)
        # return out, self.mW, z*self.mW, z*self.sW**2, z
        return out
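The dense layer above also relies on stick-breaking helpers that are not part of the snippet: kumaraswamy_sample, kumaraswamy_kl, bin_concrete_sample, and bin_concrete_kl, plus tfd as an alias for tensorflow_probability.distributions. The sketch below shows plausible samplers only (the two KL helpers are omitted); the signatures are assumptions based on how they are called here, not the repository's actual code.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def kumaraswamy_sample(conc1, conc0, sample_shape, eps=1e-8):
    """Reparameterised draw from Kumaraswamy(conc1, conc0) via the inverse CDF."""
    u = tf.random.uniform(sample_shape, minval=eps, maxval=1. - eps)
    return tf.pow(1. - tf.pow(u, 1. / conc0), 1. / conc1)

def bin_concrete_sample(probs, temperature, eps=1e-8):
    """Relaxed Bernoulli (binary Concrete) sample via the logistic reparameterisation."""
    u = tf.random.uniform(tf.shape(probs), minval=eps, maxval=1. - eps)
    logistic = tf.math.log(u) - tf.math.log(1. - u)
    logits = tf.math.log(probs + eps) - tf.math.log(1. - probs + eps)
    return tf.nn.sigmoid((logits + logistic) / temperature)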
Example #5
    def call(self, inputs, training=None):

        sW_softplus = tf.nn.softplus(self.sW)

        if training:

            layer_loss = 0.
            z = 1.
            # reparametrizable normal sample
            eps = tf.stop_gradient(tf.random.normal(self.mW.get_shape()))
            W = self.mW + eps * sW_softplus

            re = tf.ones_like(W)

            # stick breaking construction
            if self.sbp:

                conc1_softplus = tf.nn.softplus(self.conc1)
                conc0_softplus = tf.nn.softplus(self.conc0)

                # stick breaking construction
                q_u = kumaraswamy_sample(
                    conc1_softplus,
                    conc0_softplus,
                    sample_shape=[inputs.get_shape()[1], self.ksize[-2]])
                pi = tf.math.cumprod(q_u)

                # posterior Bernoulli (relaxed) probabilities
                t_pi_sigmoid = tf.nn.sigmoid(self.t_pi)

                z_sample = bin_concrete_sample(t_pi_sigmoid, self.temp_bern)
                z = tf.tile(z_sample, [self.ksize[-1]])
                re = z * W

                kl_sticks = tf.reduce_sum(
                    kumaraswamy_kl(tf.ones_like(conc1_softplus),
                                   tf.ones_like(conc0_softplus),
                                   conc1_softplus, conc0_softplus, q_u))
                kl_z = tf.reduce_sum(
                    bin_concrete_kl(pi, t_pi_sigmoid, self.temp_bern,
                                    z_sample))

                tf.compat.v1.add_to_collection('kl_loss', kl_sticks)
                tf.compat.v1.add_to_collection('kl_loss', kl_z)

                layer_loss = layer_loss + tf.math.reduce_mean(
                    kl_sticks) / 60000
                layer_loss = layer_loss + tf.math.reduce_mean(kl_z) / 60000

                tf.compat.v2.summary.scalar('kl_sticks', kl_sticks)
                tf.compat.v2.summary.scalar('kl_z', kl_z)

                # if probability of activation is smaller than tau, it's inactive
                tf.compat.v2.summary.scalar(
                    'sparsity',
                    tf.reduce_sum(
                        tf.cast(
                            tf.greater(t_pi_sigmoid / (1. + t_pi_sigmoid),
                                       self.tau), tf.float32)) *
                    self.ksize[-1])
                # sparsity = tf.reduce_sum(tf.cast(tf.greater(t_pi_sigmoid/(1.+t_pi_sigmoid), self.tau), tf.float32))*self.ksize[-1]

            # add the kl terms to the collection
            # kl_weights = tf.reduce_sum(normal_kl(tf.zeros_like(self.mW), tf.ones_like(sW_softplus), \
            # self.mW, sW_softplus, W))

            kl_weights = -0.5 * tf.reduce_mean(
                2 * sW_softplus - tf.square(self.mW) - sW_softplus**2 + 1,
                name='kl_weights')

            tf.compat.v1.add_to_collection('losses', kl_weights)
            tf.compat.v2.summary.scalar('kl_weights', kl_weights)

            layer_loss = layer_loss + tf.math.reduce_mean(kl_weights) / 60000

            # convolution operation
            lam = tf.nn.conv2d(inputs,
                               re,
                               strides=(self.strides[0], self.strides[1]),
                               padding=self.padding) + self.biases

            if self.activation == 'lwta':
                assert self.ksize[
                    -1] > 1, 'The number of competing units should be larger than 1'

                # reshape weight to calculate probabilities
                lam_re = tf.reshape(lam, [
                    -1,
                    lam.get_shape()[1],
                    lam.get_shape()[2], self.ksize[-2], self.ksize[-1]
                ])

                prbs = tf.nn.softmax(lam_re) + 1e-5
                prbs /= tf.reduce_sum(input_tensor=prbs,
                                      axis=-1,
                                      keepdims=True)

                # draw relaxed sample and apply activation
                xi = concrete_sample(prbs, self.temp_cat)

                #apply activation
                out = lam_re * xi
                out = tf.reshape(out, tf.shape(input=lam))

                # add the relative kl terms
                kl_xi = tf.reduce_mean(
                    input_tensor=tf.reduce_sum(input_tensor=concrete_kl(
                        tf.ones_like(lam_re) / self.ksize[-1], prbs, xi),
                                               axis=[1]))

                tf.compat.v1.add_to_collection('kl_loss', kl_xi)
                tf.compat.v2.summary.scalar('kl_xi', kl_xi)

                layer_loss = layer_loss + tf.math.reduce_mean(kl_xi) / 60000

            elif self.activation == 'relu':
                out = tf.nn.relu(lam)
            elif self.activation == 'maxout':
                lam_re = tf.reshape(lam, [
                    -1,
                    lam.get_shape()[1],
                    lam.get_shape()[2], self.ksize[-2], self.ksize[-1]
                ])
                out = tf.reduce_max(lam_re, -1, keepdims=False)
            elif self.activation == 'none':
                out = lam
            else:
                print('Activation:', self.activation, 'not implemented.')
                out = lam

        else:

            re = tf.ones_like(self.mW)
            z = 1.
            layer_loss = 0.

            # if sbp is active calculate mask and draw samples
            if self.sbp:

                # posterior probabilities z
                t_pi_sigmoid = tf.nn.sigmoid(self.t_pi)

                mask = tf.cast(tf.greater(t_pi_sigmoid, self.tau), tf.float32)
                z = tfd.Bernoulli(probs=mask * t_pi_sigmoid,
                                  name="q_z_test",
                                  dtype=tf.float32).sample()
                z = tf.tile(z, [self.ksize[-1]])
                re = tf.tile(mask * t_pi_sigmoid, [self.ksize[-1]])

            # convolution operation
            lam = tf.nn.conv2d(inputs,
                               re * self.mW,
                               strides=(self.strides[0], self.strides[1]),
                               padding=self.padding) + self.biases

            if self.activation == 'lwta':
                # calculate probabilities of activation
                lam_re = tf.reshape(lam, [
                    -1,
                    lam.get_shape()[1],
                    lam.get_shape()[2], self.ksize[-2], self.ksize[-1]
                ])
                prbs = tf.nn.softmax(lam_re) + 1e-5
                prbs /= tf.reduce_sum(input_tensor=prbs,
                                      axis=-1,
                                      keepdims=True)

                # draw sample for activated units
                out = lam_re * concrete_sample(prbs, 0.01)
                out = tf.reshape(out, tf.shape(input=lam))

            elif self.activation == 'relu':
                # apply relu
                out = tf.nn.relu(lam)

            elif self.activation == 'maxout':
                # apply maxout operation
                lam_re = tf.reshape(lam, [
                    -1,
                    lam.get_shape()[1],
                    lam.get_shape()[2], self.ksize[-2], self.ksize[-1]
                ])
                out = tf.reduce_max(input_tensor=lam_re, axis=-1)
            elif self.activation == 'none':
                out = lam
            else:
                print('Activation:', self.activation, 'not implemented.')
                out = lam

        self.add_loss(layer_loss)
        # return self.out, self.mW, self.z*self.mW, self.z*self.sW**2, self.z
        return out
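The test-time pruning in both layer variants follows the same pattern: components whose posterior inclusion probability t_pi_sigmoid falls below tau are cut, and the surviving weights are rescaled by that probability before being tiled across the U competitors. A small self-contained sketch of that masking step with made-up numbers (dense variant, assuming tensorflow is imported as tf as above):

t_pi_sigmoid = tf.constant([[0.95], [0.02], [0.60]])  # per-block keep probabilities
tau = 0.1
U = 2

mask = tf.cast(tf.greater(t_pi_sigmoid, tau), tf.float32)
re = tf.tile(mask * t_pi_sigmoid, [1, U])  # blocks below tau contribute nothing
print(re.numpy())
# [[0.95 0.95]
#  [0.   0.  ]
#  [0.6  0.6 ]]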