def call(self, inputs, training=None):
    lam = inputs
    K = int(inputs.shape[-1] // 2)
    U = 2

    if training:
        layer_loss = 0.

        # reshape the input for LWTA: [batch, K blocks, U competitors]
        lam_re = tf.reshape(lam, [-1, K, U])

        # calculate probability of activation and some stability operations
        prbs = tf.nn.softmax(lam_re) + 1e-4
        prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

        # relaxed categorical sample
        xi = concrete_sample(prbs, 0.67)

        # apply activation
        out = lam_re * xi
        out = tf.reshape(out, tf.shape(input=lam))

        # kl for the relaxed categorical variables
        kl_xi = tf.reduce_mean(input_tensor=tf.reduce_sum(
            input_tensor=concrete_kl(tf.ones([1, K, U]) / U, prbs, xi),
            axis=[1]))

        tf.compat.v1.add_to_collection('kl_loss', kl_xi)
        layer_loss = layer_loss + tf.math.reduce_mean(kl_xi) / 100000
        tf.compat.v2.summary.scalar(name='kl_xi', data=kl_xi)
    else:
        layer_loss = 0.

        lam_re = tf.reshape(lam, [-1, K, U])
        prbs = tf.nn.softmax(lam_re) + 1e-4
        prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

        # apply activation with a near-zero temperature (almost hard winner)
        out = lam_re * concrete_sample(prbs, 0.01)
        out = tf.reshape(out, tf.shape(input=lam))

    self.add_loss(layer_loss)

    return out, prbs
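# NOTE: `concrete_sample` is referenced throughout this file but not defined in
# this excerpt. The following is a minimal, hypothetical sketch of a relaxed
# categorical (Gumbel-Softmax) sampler consistent with how it is called in the
# layer `call` methods (a tensor of probabilities over the last axis plus a
# temperature); it is NOT the repository's actual implementation. It assumes
# TensorFlow is imported as `tf`, as in the rest of the file.
def concrete_sample(probs, temp, hard=False, eps=1e-8):
    # Gumbel noise with the same shape as the probabilities.
    uniform = tf.random.uniform(tf.shape(probs), minval=eps, maxval=1.0 - eps)
    gumbel = -tf.math.log(-tf.math.log(uniform))

    # Relaxed one-hot sample over the competing units (last axis).
    logits = tf.math.log(probs + eps)
    xi = tf.nn.softmax((logits + gumbel) / temp, axis=-1)

    if hard:
        # Straight-through variant: hard winner in the forward pass,
        # gradients flow through the relaxed sample.
        xi_hard = tf.one_hot(tf.argmax(xi, axis=-1),
                             tf.shape(xi)[-1],
                             dtype=xi.dtype)
        xi = tf.stop_gradient(xi_hard - xi) + xi
    return xi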
def lwta_activation(x, temp, K, U, train=True):
    """
    Implementation of the LWTA activation in a stochastic manner using the
    Gumbel-Softmax trick. The computation is described in the paper
    "Nonparametric Bayesian Deep Networks with Local Competition".

    @param x: tf.tensor, the input to the activation, i.e., the resulting tensor after the conv operation
    @param temp: float, the temperature of the relaxation of the categorical distribution
    @param K: int, the number of LWTA blocks we consider
    @param U: int, the number of competitors in each block
    @param train: boolean, flag to choose between the train and test branches of the function

    @return: tf.tensor, the LWTA-activated input
             tf.tensor, the KL divergence for the concrete relaxation
    """
    kl = 0

    # reshape the input for LWTA: [batch, K blocks, U competitors]
    x_reshaped = tf.reshape(x, [-1, K, U])
    logits = x_reshaped

    # relaxed categorical sample over the competitors of each block
    xi = concrete_sample(logits, temp, hard=False)

    # apply activation
    out = x_reshaped * xi
    out = tf.reshape(out, tf.shape(input=x))

    if train:
        # KL between the posterior over winners and a uniform prior over the U competitors
        q = tf.nn.softmax(logits)
        log_q = tf.math.log(q + 1e-8)
        kl = tf.reduce_sum(q * (log_q - tf.math.log(1.0 / U)), [1])
        kl = tf.reduce_mean(kl)

    return out, kl
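# Illustrative usage of `lwta_activation`. The shapes and hyperparameters are
# examples only (not taken from the repository), and it assumes the
# repository's `concrete_sample` helper is available: a batch of 32 dense
# pre-activations split into K = 64 blocks of U = 2 competing units each.
x_demo = tf.random.normal([32, 128])
out_demo, kl_demo = lwta_activation(x_demo, temp=0.67, K=64, U=2, train=True)
print(out_demo.shape)  # (32, 128): within each block, mostly one unit survives
print(kl_demo)         # scalar KL term to be scaled and added to the loss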
def call(self, inputs, training=None):
    ksize = int(inputs.shape[-1] // 2)
    lam = inputs

    if training:
        layer_loss = 0.

        # reshape the input to calculate probabilities
        lam_re = tf.reshape(
            lam, [-1, lam.get_shape()[1], lam.get_shape()[2], ksize, 2])
        prbs = tf.nn.softmax(lam_re) + 1e-5
        prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

        # draw relaxed sample and apply activation
        xi = concrete_sample(prbs, 0.5)
        out = lam_re * xi
        out = tf.reshape(out, tf.shape(input=lam))

        # add the relative kl terms
        kl_xi = tf.reduce_mean(input_tensor=tf.reduce_sum(
            input_tensor=concrete_kl(tf.ones_like(lam_re) / 2, prbs, xi),
            axis=[1]))
        layer_loss = layer_loss + tf.math.reduce_mean(kl_xi) / 100000
    else:
        layer_loss = 0.

        # calculate probabilities of activation
        lam_re = tf.reshape(
            lam, [-1, lam.get_shape()[1], lam.get_shape()[2], ksize, 2])
        prbs = tf.nn.softmax(lam_re) + 1e-5
        prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

        # draw sample for activated units with a near-zero temperature
        out = lam_re * concrete_sample(prbs, 0.01)
        out = tf.reshape(out, tf.shape(input=lam))

    self.add_loss(layer_loss)

    return out, prbs
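# NOTE: `concrete_kl` is likewise not defined in this excerpt. In
# `lwta_activation` above, the corresponding KL term is approximated by the
# categorical KL between the posterior probabilities and a uniform prior, so a
# compatible (hypothetical) sketch could look like the following; the `sample`
# argument is accepted only to match the call sites and is unused in this
# approximation. The repository's version may differ.
def concrete_kl(prior_probs, posterior_probs, sample=None, eps=1e-8):
    # KL(q || p) between categoricals over the last (competing-unit) axis.
    log_q = tf.math.log(posterior_probs + eps)
    log_p = tf.math.log(prior_probs + eps)
    return tf.reduce_sum(posterior_probs * (log_q - log_p), axis=-1)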
def call(self, inputs, training=None):
    sW_softplus = tf.nn.softplus(self.sW)

    if training:
        # reparametrizable normal sample for the weights
        eps = tf.stop_gradient(
            tf.random.normal([inputs.get_shape()[1], self.K * self.U]))
        W = self.mW + eps * sW_softplus

        z = 1.
        layer_loss = 0.

        # stick-breaking process (sbp)
        if self.sbp == True:
            # posterior concentration variables for the IBP
            conc1_softplus = tf.nn.softplus(self.conc1)
            conc0_softplus = tf.nn.softplus(self.conc0)

            # stick breaking construction
            q_u = kumaraswamy_sample(
                conc1_softplus,
                conc0_softplus,
                sample_shape=[inputs.get_shape()[1], self.K])
            pi = tf.math.cumprod(q_u)

            # posterior probabilities z
            t_pi_sigmoid = tf.nn.sigmoid(self.t_pi)

            # sample relaxed bernoulli
            z_sample = bin_concrete_sample(t_pi_sigmoid, self.temp_bern)
            z = tf.tile(z_sample, [1, self.U])
            re = z * W

            # kl terms for the stick breaking construction
            kl_sticks = tf.reduce_sum(input_tensor=kumaraswamy_kl(
                tf.ones_like(conc1_softplus), tf.ones_like(conc0_softplus),
                conc1_softplus, conc0_softplus, q_u))
            kl_z = tf.reduce_sum(input_tensor=bin_concrete_kl(
                pi, t_pi_sigmoid, self.temp_bern, z_sample))

            tf.compat.v1.add_to_collection('kl_loss', kl_sticks)
            tf.compat.v1.add_to_collection('kl_loss', kl_z)
            layer_loss = layer_loss + tf.math.reduce_mean(kl_sticks) / 60000
            layer_loss = layer_loss + tf.math.reduce_mean(kl_z) / 60000

            tf.compat.v2.summary.scalar(name='kl_sticks', data=kl_sticks)
            tf.compat.v2.summary.scalar(name='kl_z', data=kl_z)

            # cut connections if the probability of activation is less than tau
            tf.compat.v2.summary.scalar(
                name='sparsity',
                data=tf.reduce_sum(input_tensor=tf.cast(
                    tf.greater(t_pi_sigmoid / (1. + t_pi_sigmoid), self.tau),
                    tf.float32)) * self.U)
        else:
            re = W

        # add the kl for the weights to the collection
        kl_weights = -0.5 * tf.reduce_mean(
            2 * sW_softplus - tf.square(self.mW) - sW_softplus**2 + 1,
            name='kl_weights')
        tf.compat.v1.add_to_collection('kl_loss', kl_weights)
        layer_loss = layer_loss + tf.math.reduce_mean(kl_weights) / 60000
        tf.compat.v2.summary.scalar(name='kl_weights', data=kl_weights)

        # dense calculation
        lam = tf.matmul(inputs, re) + self.biases

        if self.activation == 'lwta':
            assert self.U > 1, 'The number of competing units should be larger than 1'

            # reshape the pre-activations for LWTA
            lam_re = tf.reshape(lam, [-1, self.K, self.U])

            # calculate probability of activation and some stability operations
            prbs = tf.nn.softmax(lam_re) + 1e-4
            prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

            # relaxed categorical sample
            xi = concrete_sample(prbs, self.temp_cat)

            # apply activation
            out = lam_re * xi
            out = tf.reshape(out, tf.shape(input=lam))

            # kl for the relaxed categorical variables
            kl_xi = tf.reduce_mean(
                input_tensor=tf.reduce_sum(input_tensor=concrete_kl(
                    tf.ones([1, self.K, self.U]) / self.U, prbs, xi),
                    axis=[1]))

            tf.compat.v1.add_to_collection('kl_loss', kl_xi)
            layer_loss = layer_loss + tf.math.reduce_mean(kl_xi) / 60000
            tf.compat.v2.summary.scalar(name='kl_xi', data=kl_xi)
        elif self.activation == 'relu':
            out = tf.nn.relu(lam)
        elif self.activation == 'maxout':
            lam_re = tf.reshape(lam, [-1, self.K, self.U])
            out = tf.reduce_max(input_tensor=lam_re, axis=-1)
        else:
            out = lam

    # test branch of the layer; it is activated automatically in the model
    else:
        # NOTE: this is very different from the original:
        # we use re for accuracy and z for compression (if sbp is active)
        re = 1.
        z = 1.
        layer_loss = 0.

        # sbp: keep only connections whose activation probability exceeds tau
        if self.sbp == True:
            # posterior probabilities z
            t_pi_sigmoid = tf.nn.sigmoid(self.t_pi)
            mask = tf.cast(tf.greater(t_pi_sigmoid, self.tau), tf.float32)

            z = tfd.Bernoulli(probs=mask * t_pi_sigmoid,
                              name="q_z_test",
                              dtype=tf.float32).sample()
            z = tf.tile(z, [1, self.U])
            re = tf.tile(mask * t_pi_sigmoid, [1, self.U])

        lam = tf.matmul(inputs, re * self.mW) + self.biases

        if self.activation == 'lwta':
            # reshape and calculate the winners
            lam_re = tf.reshape(lam, [-1, self.K, self.U])
            prbs = tf.nn.softmax(lam_re) + 1e-4
            prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

            # apply activation with a near-zero temperature
            out = lam_re * concrete_sample(prbs, 0.01)
            out = tf.reshape(out, tf.shape(input=lam))
        elif self.activation == 'relu':
            out = tf.nn.relu(lam)
        elif self.activation == 'maxout':
            lam_re = tf.reshape(lam, [-1, self.K, self.U])
            out = tf.reduce_max(input_tensor=lam_re, axis=-1)
        else:
            out = lam

    self.add_loss(layer_loss)

    return out
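# NOTE: the stick-breaking branches above rely on `kumaraswamy_sample` and
# `bin_concrete_sample`, which are not part of this excerpt. Below are minimal,
# hypothetical sketches of the reparameterised samplers they are assumed to
# implement (Kumaraswamy inverse-CDF sampling and a binary concrete / relaxed
# Bernoulli); the repository's versions may differ.
def kumaraswamy_sample(a, b, sample_shape, eps=1e-8):
    # Inverse-CDF sampling: u ~ U(0,1), x = (1 - (1 - u)^(1/b))^(1/a).
    u = tf.random.uniform(sample_shape, minval=eps, maxval=1.0 - eps)
    return tf.pow(1.0 - tf.pow(1.0 - u, 1.0 / b), 1.0 / a)


def bin_concrete_sample(probs, temp, eps=1e-8):
    # Relaxed Bernoulli sample with success probability `probs` and
    # temperature `temp`, using logistic noise.
    u = tf.random.uniform(tf.shape(probs), minval=eps, maxval=1.0 - eps)
    logistic = tf.math.log(u) - tf.math.log(1.0 - u)
    logit = tf.math.log(probs + eps) - tf.math.log(1.0 - probs + eps)
    return tf.nn.sigmoid((logit + logistic) / temp)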
def call(self, inputs, training=None):
    sW_softplus = tf.nn.softplus(self.sW)

    if training:
        layer_loss = 0.
        z = 1.

        # reparametrizable normal sample for the kernel
        eps = tf.stop_gradient(tf.random.normal(self.mW.get_shape()))
        W = self.mW + eps * sW_softplus
        re = tf.ones_like(W)

        # stick breaking construction (sbp)
        if self.sbp == True:
            conc1_softplus = tf.nn.softplus(self.conc1)
            conc0_softplus = tf.nn.softplus(self.conc0)

            # stick breaking construction
            q_u = kumaraswamy_sample(
                conc1_softplus,
                conc0_softplus,
                sample_shape=[inputs.get_shape()[1], self.ksize[-2]])
            pi = tf.math.cumprod(q_u)

            # posterior bernoulli (relaxed) probabilities
            t_pi_sigmoid = tf.nn.sigmoid(self.t_pi)
            z_sample = bin_concrete_sample(t_pi_sigmoid, self.temp_bern)
            z = tf.tile(z_sample, [self.ksize[-1]])
            re = z * W

            kl_sticks = tf.reduce_sum(
                kumaraswamy_kl(tf.ones_like(conc1_softplus),
                               tf.ones_like(conc0_softplus),
                               conc1_softplus, conc0_softplus, q_u))
            kl_z = tf.reduce_sum(
                bin_concrete_kl(pi, t_pi_sigmoid, self.temp_bern, z_sample))

            tf.compat.v1.add_to_collection('kl_loss', kl_sticks)
            tf.compat.v1.add_to_collection('kl_loss', kl_z)
            layer_loss = layer_loss + tf.math.reduce_mean(kl_sticks) / 60000
            layer_loss = layer_loss + tf.math.reduce_mean(kl_z) / 60000

            tf.compat.v2.summary.scalar('kl_sticks', kl_sticks)
            tf.compat.v2.summary.scalar('kl_z', kl_z)

            # if the probability of activation is smaller than tau, the feature map is inactive
            tf.compat.v2.summary.scalar(
                'sparsity',
                tf.reduce_sum(
                    tf.cast(
                        tf.greater(t_pi_sigmoid / (1. + t_pi_sigmoid),
                                   self.tau), tf.float32)) * self.ksize[-1])

        # add the kl term for the weights to the collection
        kl_weights = -0.5 * tf.reduce_mean(
            2 * sW_softplus - tf.square(self.mW) - sW_softplus**2 + 1,
            name='kl_weights')
        tf.compat.v1.add_to_collection('losses', kl_weights)
        tf.compat.v2.summary.scalar('kl_weights', kl_weights)
        layer_loss = layer_loss + tf.math.reduce_mean(kl_weights) / 60000

        # convolution operation
        lam = tf.nn.conv2d(inputs,
                           re,
                           strides=(self.strides[0], self.strides[1]),
                           padding=self.padding) + self.biases

        if self.activation == 'lwta':
            assert self.ksize[-1] > 1, 'The number of competing units should be larger than 1'

            # reshape the feature maps to calculate probabilities
            lam_re = tf.reshape(lam, [
                -1, lam.get_shape()[1], lam.get_shape()[2], self.ksize[-2],
                self.ksize[-1]
            ])
            prbs = tf.nn.softmax(lam_re) + 1e-5
            prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

            # draw a relaxed sample and apply the activation
            xi = concrete_sample(prbs, self.temp_cat)
            out = lam_re * xi
            out = tf.reshape(out, tf.shape(input=lam))

            # add the relative kl terms
            kl_xi = tf.reduce_mean(
                input_tensor=tf.reduce_sum(input_tensor=concrete_kl(
                    tf.ones_like(lam_re) / self.ksize[-1], prbs, xi),
                    axis=[1]))
            tf.compat.v1.add_to_collection('kl_loss', kl_xi)
            tf.compat.v2.summary.scalar('kl_xi', kl_xi)
            layer_loss = layer_loss + tf.math.reduce_mean(kl_xi) / 60000
        elif self.activation == 'relu':
            out = tf.nn.relu(lam)
        elif self.activation == 'maxout':
            lam_re = tf.reshape(lam, [
                -1, lam.get_shape()[1], lam.get_shape()[2], self.ksize[-2],
                self.ksize[-1]
            ])
            out = tf.reduce_max(lam_re, -1, keepdims=False)
        elif self.activation == 'none':
            out = lam
        else:
            print('Activation:', self.activation, 'not implemented.')
            out = lam
    else:
        re = tf.ones_like(self.mW)
        z = 1.
        layer_loss = 0.

        # if sbp is active, calculate the mask and draw samples
        if self.sbp:
            # posterior probabilities z
            t_pi_sigmoid = tf.nn.sigmoid(self.t_pi)
            mask = tf.cast(tf.greater(t_pi_sigmoid, self.tau), tf.float32)

            z = tfd.Bernoulli(probs=mask * t_pi_sigmoid,
                              name="q_z_test",
                              dtype=tf.float32).sample()
            z = tf.tile(z, [self.ksize[-1]])
            re = tf.tile(mask * t_pi_sigmoid, [self.ksize[-1]])

        # convolution operation
        lam = tf.nn.conv2d(inputs,
                           re * self.mW,
                           strides=(self.strides[0], self.strides[1]),
                           padding=self.padding) + self.biases

        if self.activation == 'lwta':
            # calculate probabilities of activation
            lam_re = tf.reshape(lam, [
                -1, lam.get_shape()[1], lam.get_shape()[2], self.ksize[-2],
                self.ksize[-1]
            ])
            prbs = tf.nn.softmax(lam_re) + 1e-5
            prbs /= tf.reduce_sum(input_tensor=prbs, axis=-1, keepdims=True)

            # draw a sample for the activated units with a near-zero temperature
            out = lam_re * concrete_sample(prbs, 0.01)
            out = tf.reshape(out, tf.shape(input=lam))
        elif self.activation == 'relu':
            # apply relu
            out = tf.nn.relu(lam)
        elif self.activation == 'maxout':
            # apply the maxout operation
            lam_re = tf.reshape(lam, [
                -1, lam.get_shape()[1], lam.get_shape()[2], self.ksize[-2],
                self.ksize[-1]
            ])
            out = tf.reduce_max(input_tensor=lam_re, axis=-1)
        elif self.activation == 'none':
            out = lam
        else:
            print('Activation:', self.activation, 'not implemented.')
            out = lam

    self.add_loss(layer_loss)

    return out
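# At test time the relaxed sample is drawn with a near-zero temperature (0.01),
# so the LWTA branches above behave almost like a deterministic, hard
# winner-take-all within each block. The hypothetical helper below makes that
# limiting behaviour explicit for the dense case (illustrative only; the layers
# above still sample from the relaxed distribution).
def lwta_hard(lam, K, U):
    lam_re = tf.reshape(lam, [-1, K, U])
    # keep only the block-wise maximum, zero out the losing competitors
    winners = tf.one_hot(tf.argmax(lam_re, axis=-1), U, dtype=lam.dtype)
    return tf.reshape(lam_re * winners, tf.shape(lam))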