def apply_softmax(self, wrapper):
    ub = self.upper
    lb = self.lower
    # Keep diagonal and take opposite bound for non-diagonals.
    lbs = tf.matrix_diag(lb) + tf.expand_dims(ub, axis=-2) - tf.matrix_diag(ub)
    ubs = tf.matrix_diag(ub) + tf.expand_dims(lb, axis=-2) - tf.matrix_diag(lb)
    # Get diagonal entries after softmax operation.
    ubs = tf.matrix_diag_part(tf.nn.softmax(ubs))
    lbs = tf.matrix_diag_part(tf.nn.softmax(lbs))
    return IntervalBounds(lbs, ubs)
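# Illustrative check (not part of the original snippet above): softmax output k is
# monotonically increasing in logit k and decreasing in every other logit, so the
# per-class bounds built in apply_softmax (own bound on the diagonal, opposite bound
# off the diagonal) are valid interval bounds. A minimal NumPy sketch, assuming box
# bounds lb <= ub on the logits:
import numpy as np


def check_softmax_interval_bounds(n_classes=5, n_samples=100, seed=0):
    rng = np.random.default_rng(seed)
    softmax = lambda z: np.exp(z - z.max()) / np.exp(z - z.max()).sum()
    lb = rng.normal(size=n_classes)
    ub = lb + rng.uniform(0.1, 1.0, size=n_classes)
    for k in range(n_classes):
        upper_logits = lb.copy(); upper_logits[k] = ub[k]  # diagonal from ub, off-diagonal from lb
        lower_logits = ub.copy(); lower_logits[k] = lb[k]  # diagonal from lb, off-diagonal from ub
        for _ in range(n_samples):
            z = rng.uniform(lb, ub)  # any logit vector inside the box
            assert softmax(lower_logits)[k] <= softmax(z)[k] <= softmax(upper_logits)[k]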
def _slice_cov(self, cov):
    """
    Slice the correct dimensions for use in the kernel, as indicated by
    `self.active_dims` for covariance matrices. This requires slicing the
    rows *and* columns. This will also turn flattened diagonal matrices into
    a tensor of full diagonal matrices.

    :param cov: Tensor of covariance matrices (NxDxD or NxD).
    :return: N x self.input_dim x self.input_dim.
    """
    cov = tf.cond(tf.equal(tf.rank(cov), 2), lambda: tf.matrix_diag(cov), lambda: cov)

    if isinstance(self.active_dims, slice):
        cov = cov[..., self.active_dims, self.active_dims]
    else:
        cov_shape = tf.shape(cov)
        covr = tf.reshape(cov, [-1, cov_shape[-1], cov_shape[-1]])
        gather1 = tf.gather(tf.transpose(covr, [2, 1, 0]), self.active_dims)
        gather2 = tf.gather(tf.transpose(gather1, [1, 0, 2]), self.active_dims)
        cov = tf.reshape(
            tf.transpose(gather2, [2, 0, 1]),
            tf.concat([cov_shape[:-2],
                       [len(self.active_dims), len(self.active_dims)]], 0))
    return cov
def log_prob_fn(params):
    rho, alpha, sigma = tf.split(params, [num_features, 1, 1], -1)
    one = tf.ones(num_features)

    def indep(d):
        return tfd.Independent(d, 1)

    p_rho = indep(tfd.InverseGamma(5. * one, 5. * one))
    p_alpha = indep(tfd.HalfNormal([1.]))
    p_sigma = indep(tfd.HalfNormal([1.]))

    rho_shape = tf.shape(rho)
    alpha_shape = tf.shape(alpha)

    x1 = tf.expand_dims(x, -2)
    x2 = tf.expand_dims(x, -3)
    exp = -0.5 * tf.squared_difference(x1, x2)
    exp /= tf.reshape(tf.square(rho),
                      tf.concat([rho_shape[:1], [1, 1], rho_shape[1:]], 0))
    exp = tf.reduce_sum(exp, -1, keep_dims=True)
    exp += 2. * tf.reshape(tf.log(alpha),
                           tf.concat([alpha_shape[:1], [1, 1], alpha_shape[1:]], 0))
    exp = tf.exp(exp[Ellipsis, 0])
    exp += tf.matrix_diag(tf.tile(tf.square(sigma), [1, int(x.shape[0])]) + 1e-6)
    exp = tf.check_numerics(exp, "exp 2 has NaNs")
    with tf.control_dependencies([tf.print(exp[0], summarize=99999)]):
        exp = tf.identity(exp)
    p_y = tfd.MultivariateNormalFullCovariance(covariance_matrix=exp)

    log_prob = (
        p_rho.log_prob(rho) + p_alpha.log_prob(alpha) +
        p_sigma.log_prob(sigma) + p_y.log_prob(y))

    return log_prob
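# Hypothetical plain-NumPy rendering (not from the source) of the covariance assembled in
# log_prob_fn above: an ARD squared-exponential kernel with amplitude `alpha`, per-feature
# lengthscales `rho`, observation noise `sigma`, and the same 1e-6 jitter on the diagonal:
# K_ij = alpha^2 * exp(-0.5 * sum_d (x_i[d] - x_j[d])^2 / rho[d]^2) + (sigma^2 + 1e-6) * I.
import numpy as np


def gp_covariance(x, rho, alpha, sigma, jitter=1e-6):
    """x: (N, D) inputs, rho: (D,) lengthscales, alpha/sigma: scalars -> (N, N) covariance."""
    sq_diff = (x[:, None, :] - x[None, :, :]) ** 2         # N x N x D
    exp_term = -0.5 * np.sum(sq_diff / rho ** 2, axis=-1)  # N x N
    return alpha ** 2 * np.exp(exp_term) + (sigma ** 2 + jitter) * np.eye(len(x))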
def build(self, input_shape):
    input_shape = tf.TensorShape(input_shape)
    if dimension_value(input_shape[-1]) is None:
        raise ValueError("The last dimension of the inputs to `Dense` "
                         "should be defined. Found `None`.")
    last_dim = dimension_value(input_shape[-1])
    self.input_spec = tf.keras.layers.InputSpec(min_ndim=2, axes={-1: last_dim})

    self._c = tf.get_variable(
        "c", [self._decoder_dim, self._rank],
        initializer=tf.contrib.layers.xavier_initializer(),
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    sigma = tf.matmul(self._mem_input, self._c)
    if self._sigma_norm > 0.:
        sigma = tf.nn.l2_normalize(sigma, axis=1) * self._sigma_norm
    elif self._sigma_norm == -1.:
        sigma = tf.nn.softmax(sigma / self._tau, axis=1)
    sigma_diag = tf.matrix_diag(sigma)

    self._u = tf.get_variable(
        "u", [last_dim, self._rank],
        initializer=tf.contrib.layers.xavier_initializer(),
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    self._v = tf.get_variable(
        "v", [self._rank, self.units],
        initializer=tf.contrib.layers.xavier_initializer(),
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)
    self.kernel = tf.einsum("ij,ajk,kl->ail", self._u, sigma_diag, self._v)
    if self._use_beam and self._beam_width:
        self.kernel = tf.contrib.seq2seq.tile_batch(
            self.kernel, multiplier=self._beam_width)

    if self.use_bias:
        self._b = self.add_weight(
            "b",
            shape=[self.units, self._rank],
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
            dtype=self.dtype,
            trainable=True)
        self.bias = tf.einsum("ij,aj->ai", self._b, sigma)
        if self._use_beam and self._beam_width:
            self.bias = tf.contrib.seq2seq.tile_batch(
                self.bias, multiplier=self._beam_width)
    else:
        self.bias = None

    self.built = True
def linear_certain_activations(x_certain, A, b):
    """Compute y = x^T A + b assuming x has zero variance."""
    x_mean = x_certain
    xx = x_mean * x_mean
    y_mean = tf.matmul(x_mean, A.mean) + b.mean
    y_cov = tf.matrix_diag(tf.matmul(xx, A.var) + b.var)
    return gv.GaussianVar(y_mean, y_cov)
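# Standalone Monte-Carlo sanity check (illustrative, not from the source) of the variance
# propagation in linear_certain_activations: for a deterministic x and element-wise
# independent A and b, Var[y_j] = sum_i x_i^2 * Var[A_ij] + Var[b_j], i.e. exactly the
# diagonal covariance built with tf.matrix_diag above.
import numpy as np


def check_linear_certain_variance(d_in=3, d_out=2, n=50_000, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.normal(size=d_in)
    A_mean, A_var = rng.normal(size=(d_in, d_out)), rng.uniform(0.1, 1.0, size=(d_in, d_out))
    b_mean, b_var = rng.normal(size=d_out), rng.uniform(0.1, 1.0, size=d_out)
    # Sample A and b, push the fixed x through, and compare empirical vs analytic variance.
    A_samples = A_mean + rng.normal(size=(n, d_in, d_out)) * np.sqrt(A_var)
    b_samples = b_mean + rng.normal(size=(n, d_out)) * np.sqrt(b_var)
    y = np.einsum("i,nij->nj", x, A_samples) + b_samples
    analytic_var = (x ** 2) @ A_var + b_var
    return np.allclose(y.var(axis=0), analytic_var, rtol=0.05)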
def initialize(self):
    """Construct RNN params."""
    sigma = tf.matmul(self._mem_input, self._c)
    if self._sigma_norm > 0.:
        sigma = tf.nn.l2_normalize(sigma, axis=1) * self._sigma_norm
    elif self._sigma_norm == -1.:
        sigma = tf.nn.softmax(sigma, axis=1)
    sigma_diag = tf.matrix_diag(sigma)

    # The weight matrices.
    # {`x`: input, `h`: hidden_state}.
    # {`i`: input_gate, `j`: cell_state, `f`: forget_gate, `o`: output_gate}.

    # Weight matrix that maps input `x` to input_gate `i`.
    w_xi = tf.einsum("ij,ajk,kl->ail", self._u_xi, sigma_diag, self._v_xi)
    # Weight matrix that maps input `x` to cell_state `j`.
    w_xj = tf.einsum("ij,ajk,kl->ail", self._u_xj, sigma_diag, self._v_xj)
    # Weight matrix that maps input `x` to forget_gate `f`.
    w_xf = tf.einsum("ij,ajk,kl->ail", self._u_xf, sigma_diag, self._v_xf)
    # Weight matrix that maps input `x` to output_gate `o`.
    w_xo = tf.einsum("ij,ajk,kl->ail", self._u_xo, sigma_diag, self._v_xo)
    # Weight matrix that maps hidden_state `h` to input_gate `i`.
    w_hi = tf.einsum("ij,ajk,kl->ail", self._u_hi, sigma_diag, self._v_hi)
    # Weight matrix that maps hidden_state `h` to cell_state `j`.
    w_hj = tf.einsum("ij,ajk,kl->ail", self._u_hj, sigma_diag, self._v_hj)
    # Weight matrix that maps hidden_state `h` to forget_gate `f`.
    w_hf = tf.einsum("ij,ajk,kl->ail", self._u_hf, sigma_diag, self._v_hf)
    # Weight matrix that maps hidden_state `h` to output_gate `o`.
    w_ho = tf.einsum("ij,ajk,kl->ail", self._u_ho, sigma_diag, self._v_ho)

    w_x = tf.concat([w_xi, w_xj, w_xf, w_xo], axis=2)
    w_h = tf.concat([w_hi, w_hj, w_hf, w_ho], axis=2)
    self._weight = tf.concat([w_x, w_h], axis=1)
    self._bias = tf.einsum("ij,aj->ai", self._b, sigma)

    if self._use_beam and self._beam_width > 1:
        self._weight = tf.contrib.seq2seq.tile_batch(
            self._weight, multiplier=self._beam_width)
        self._bias = tf.contrib.seq2seq.tile_batch(
            self._bias, multiplier=self._beam_width)
def __init__(self, priors, **kwargs):
    Prior.__init__(self)
    self.priors = priors
    self.name = kwargs.get("name", "FactPrior")
    self.nparams = len(priors)

    means = [prior.mean for prior in self.priors]
    variances = [prior.var for prior in self.priors]
    self.mean = self.log_tf(tf.stack(means, axis=-1, name="%s_mean" % self.name))
    self.var = self.log_tf(tf.stack(variances, axis=-1, name="%s_var" % self.name))
    self.std = tf.sqrt(self.var, name="%s_std" % self.name)
    self.nvertices = priors[0].nvertices

    # Define a diagonal covariance matrix for convenience
    self.cov = tf.matrix_diag(self.var, name='%s_cov' % self.name)
def regularizer(self, kl_loss, z_mean, z_logvar, z_sampled):
    cov_z_mean = compute_covariance_z_mean(z_mean)
    lambda_d = self.lambda_d_factor * self.lambda_od
    if self.dip_type == "i":  # Eq 6 page 4
        # mu = z_mean is [batch_size, num_latent]
        # Compute cov_p(x) [mu(x)] = E[mu*mu^T] - E[mu]E[mu]^T
        cov_dip_regularizer = regularize_diag_off_diag_dip(
            cov_z_mean, self.lambda_od, lambda_d)
    elif self.dip_type == "ii":
        cov_enc = tf.matrix_diag(tf.exp(z_logvar))
        expectation_cov_enc = tf.reduce_mean(cov_enc, axis=0)
        cov_z = expectation_cov_enc + cov_z_mean
        cov_dip_regularizer = regularize_diag_off_diag_dip(
            cov_z, self.lambda_od, lambda_d)
    else:
        raise NotImplementedError("DIP variant not supported.")
    return kl_loss + cov_dip_regularizer
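# Hypothetical NumPy sketch (an assumption, not the library's implementation) of the batch
# statistic referenced in the comment above: compute_covariance_z_mean presumably returns
# Cov_p(x)[mu(x)] = E[mu mu^T] - E[mu] E[mu]^T taken over the batch dimension of z_mean.
import numpy as np


def covariance_z_mean_np(z_mean):
    """z_mean: array of shape [batch_size, num_latent] -> [num_latent, num_latent]."""
    e_mu = z_mean.mean(axis=0)
    e_mu_mu_t = np.einsum("bi,bj->ij", z_mean, z_mean) / z_mean.shape[0]
    return e_mu_mu_t - np.outer(e_mu, e_mu)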
def grad(grad_e, grad_v):
    """Gradient for SelfAdjointEigV2."""
    with tf.control_dependencies([grad_e, grad_v]):
        ediffs = tf.expand_dims(e, -2) - tf.expand_dims(e, -1)
        # Avoid NaNs from reciprocals when eigenvalues are close.
        safe_recip = tf.where(ediffs**2 < 1e-10,
                              tf.zeros_like(ediffs),
                              tf.reciprocal(ediffs))
        f = tf.matrix_set_diag(safe_recip, tf.zeros_like(e))
        grad_a = tf.matmul(
            v,
            tf.matmul(
                tf.matrix_diag(grad_e) + f * tf.matmul(v, grad_v, adjoint_a=True),
                v, adjoint_b=True))
        # The forward op only depends on the lower triangular part of a, so here we
        # symmetrize and take the lower triangle
        grad_a = tf.linalg.band_part(grad_a + tf.linalg.adjoint(grad_a), -1, 0)
        grad_a = tf.linalg.set_diag(grad_a, 0.5 * tf.matrix_diag_part(grad_a))
        return grad_a
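# Illustrative finite-difference check (not part of the original gradient code) of the
# eigenvalue term above: for symmetric A with A = V diag(e) V^T and a loss L = sum_k e_k^2,
# the gradient is V diag(2e) V^T, matching tf.matrix_diag(grad_e) sandwiched between v and
# v^T when grad_v is zero.
import numpy as np


def check_eigval_gradient(n=4, eps=1e-5, seed=0):
    rng = np.random.default_rng(seed)
    A = rng.normal(size=(n, n))
    A = 0.5 * (A + A.T)
    e, v = np.linalg.eigh(A)
    analytic = v @ np.diag(2.0 * e) @ v.T
    # Compare directional derivatives along random symmetric directions.
    for _ in range(5):
        D = rng.normal(size=(n, n))
        D = 0.5 * (D + D.T)
        lp = np.sum(np.linalg.eigvalsh(A + eps * D) ** 2)
        lm = np.sum(np.linalg.eigvalsh(A - eps * D) ** 2)
        if not np.isclose(np.sum(analytic * D), (lp - lm) / (2.0 * eps), atol=1e-4):
            return False
    return True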
def __init__(self, posts, **kwargs):
    FactorisedPosterior.__init__(self, posts, **kwargs)

    # The full covariance matrix is formed from the Cholesky decomposition
    # to ensure that it remains positive definite.
    #
    # To achieve this, we have to create PxP tensor variables for
    # each parameter vertex, but we then extract only the lower triangular
    # elements and train only on these. The diagonal elements
    # are constructed by the FactorisedPosterior
    if kwargs.get("init", None):
        # We are initializing from an existing posterior.
        # The FactorisedPosterior will already have extracted the mean and
        # diagonal of the covariance matrix - we need the Cholesky decomposition
        # of the covariance to initialize the off-diagonal terms
        self.log.info(" - Initializing posterior covariance from input posterior")
        _mean, cov = kwargs["init"]
        covar_init = tf.cholesky(cov)
    else:
        covar_init = tf.zeros([self.nvertices, self.nparams, self.nparams], dtype=tf.float32)

    self.off_diag_vars_base = self.log_tf(tf.Variable(covar_init, validate_shape=False,
                                                      name='%s_off_diag_vars' % self.name))
    if kwargs.get("suppress_nan", True):
        self.off_diag_vars = tf.where(tf.is_nan(self.off_diag_vars_base),
                                      tf.zeros_like(self.off_diag_vars_base),
                                      self.off_diag_vars_base)
    else:
        self.off_diag_vars = self.off_diag_vars_base

    self.off_diag_cov_chol = tf.matrix_set_diag(tf.matrix_band_part(self.off_diag_vars, -1, 0),
                                                tf.zeros([self.nvertices, self.nparams]),
                                                name='%s_off_diag_cov_chol' % self.name)

    # Combine diagonal and off-diagonal elements into full matrix
    self.cov_chol = tf.add(tf.matrix_diag(self.std), self.off_diag_cov_chol,
                           name='%s_cov_chol' % self.name)

    # Form the covariance matrix from the chol decomposition
    self.cov = tf.matmul(tf.transpose(self.cov_chol, perm=(0, 2, 1)), self.cov_chol,
                         name='%s_cov' % self.name)

    self.cov_chol = self.log_tf(self.cov_chol)
    self.cov = self.log_tf(self.cov)
def __init__(self, posts, **kwargs):
    Posterior.__init__(self, -1, **kwargs)
    self.posts = posts
    self.nparams = len(self.posts)
    self.name = kwargs.get("name", "FactPost")

    means = [post.mean for post in self.posts]
    variances = [post.var for post in self.posts]
    mean = tf.stack(means, axis=-1, name="%s_mean" % self.name)
    var = tf.stack(variances, axis=-1, name="%s_var" % self.name)

    self.mean = self.log_tf(tf.identity(mean, name="%s_mean" % self.name))
    self.var = self.log_tf(tf.identity(var, name="%s_var" % self.name))
    self.std = tf.sqrt(self.var, name="%s_std" % self.name)
    self.nvertices = posts[0].nvertices

    # Covariance matrix is diagonal
    self.cov = tf.matrix_diag(self.var, name='%s_cov' % self.name)

    # Regularisation to make sure cov is invertible. Note that we do not
    # need this for a diagonal covariance matrix but it is useful for
    # the full MVN covariance which shares some of the calculations
    self.cov_reg = 1e-5 * tf.eye(self.nparams)
def test_fac_init_mvn():
    """Test factorised posterior initialized from a full MVN"""
    with tf.Session() as session:
        nparams_in = 4
        nvoxels_in = 34
        posts = []
        means = np.random.normal(5.0, 3.0, [nvoxels_in, nparams_in])
        variances = np.square(np.random.normal(2.5, 1.6, [nvoxels_in, nparams_in]))
        name = "TestFactorisedPosterior"
        for param in range(nparams_in):
            posts.append(NormalPosterior(means[:, param], variances[:, param]))

        mvn_post = MVNPosterior(posts, name=name)
        session.run(tf.global_variables_initializer())
        init = (session.run(mvn_post.mean), session.run(mvn_post.cov))

        fac_post = FactorisedPosterior(posts, init=init)
        session.run(tf.global_variables_initializer())

        assert np.allclose(session.run(fac_post.mean), session.run(mvn_post.mean))
        # Expect the covariance to just contain the diagonal elements
        diag_cov = tf.matrix_diag(tf.matrix_diag_part(mvn_post.cov))
        assert np.allclose(session.run(fac_post.cov), session.run(diag_cov))
def sigma_tf(self, t, X, Y):  # M x 1, M x D, M x 1
    M = self.M
    D = self.D
    return tf.constant(self.snoise) * tf.matrix_diag(tf.ones([M, D]))
def multihead_invertible_1x1_conv_np(name, x, x_mask, multihead_split, inverse, dtype):
    """Multi-head 1X1 convolution on x."""
    batch_size, length, n_channels_all = common_layers.shape_list(x)
    assert n_channels_all % 32 == 0
    n_channels = 32
    n_1x1_heads = n_channels_all // n_channels

    def get_init_np():
        """Initializer function for multihead 1x1 parameters using numpy."""
        results = []
        for _ in range(n_1x1_heads):
            random_matrix = np.random.rand(n_channels, n_channels)
            np_w = scipy.linalg.qr(random_matrix)[0].astype("float32")
            np_p, np_l, np_u = scipy.linalg.lu(np_w)
            np_s = np.diag(np_u)
            np_sign_s = np.sign(np_s)[np.newaxis, :]
            np_log_s = np.log(np.abs(np_s))[np.newaxis, :]
            np_u = np.triu(np_u, k=1)
            results.append(
                np.concatenate([np_p, np_l, np_u, np_sign_s, np_log_s], axis=0))
        return tf.convert_to_tensor(np.stack(results, axis=0))

    def get_mask_init():
        ones = tf.ones([n_1x1_heads, n_channels, n_channels], dtype=dtype)
        l_mask = tf.matrix_band_part(ones, -1, 0) - tf.matrix_band_part(ones, 0, 0)
        u_mask = tf.matrix_band_part(ones, 0, -1) - tf.matrix_band_part(ones, 0, 0)
        return tf.stack([l_mask, u_mask], axis=0)

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        params = tf.get_variable("params", initializer=get_init_np, dtype=dtype)
        mask_params = tf.get_variable("mask_params", initializer=get_mask_init,
                                      dtype=dtype, trainable=False)

        p = tf.stop_gradient(params[:, :n_channels, :])
        l = params[:, n_channels:2 * n_channels, :]
        u = params[:, 2 * n_channels:3 * n_channels, :]
        sign_s = tf.stop_gradient(params[:, 3 * n_channels, :])
        log_s = params[:, 3 * n_channels + 1, :]

        l_mask = mask_params[0]
        u_mask = mask_params[1]
        l_diag = l * l_mask + (tf.eye(n_channels, n_channels, [n_1x1_heads], dtype=dtype))
        u_diag = u * u_mask + (tf.matrix_diag(sign_s * tf.exp(log_s)))
        w = tf.matmul(p, tf.matmul(l_diag, u_diag))

        if multihead_split == "a":
            x = tf.reshape(x, [batch_size, length, n_channels, n_1x1_heads])
            x = tf.transpose(x, [3, 0, 1, 2])
        elif multihead_split == "c":
            x = tf.reshape(x, [batch_size, length, n_1x1_heads, n_channels])
            x = tf.transpose(x, [2, 0, 1, 3])
        else:
            raise ValueError("Multihead split not supported.")
        # [n_1x1_heads, batch_size, length, n_channels]

        if not inverse:
            # [n_1x1_heads, 1, n_channels, n_channels]
            x = tf.matmul(x, w[:, tf.newaxis, :, :])
        else:
            w_inv = tf.matrix_inverse(w)
            x = tf.matmul(x, w_inv[:, tf.newaxis, :, :])

        if multihead_split == "a":
            x = tf.transpose(x, [1, 2, 3, 0])
            x = tf.reshape(x, [batch_size, length, n_channels * n_1x1_heads])
        elif multihead_split == "c":
            x = tf.transpose(x, [1, 2, 0, 3])
            x = tf.reshape(x, [batch_size, length, n_1x1_heads * n_channels])
        else:
            raise ValueError("Multihead split not supported.")

        x_length = tf.reduce_sum(x_mask, -1)
        logabsdet = x_length * tf.reduce_sum(log_s)
        if inverse:
            logabsdet *= -1
        return x, logabsdet
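# Illustrative NumPy/SciPy check (assumptions: same LU parameterisation as above, no TF)
# that reconstructing W = P @ L @ U with unit-diagonal L and diag(U) = sign_s * exp(log_s)
# gives log|det W| = sum(log_s), which is what `logabsdet` relies on per timestep.
import numpy as np
import scipy.linalg


def check_lu_logdet(n=8, seed=0):
    rng = np.random.default_rng(seed)
    w0 = rng.normal(size=(n, n))                              # any invertible init
    p, l, u = scipy.linalg.lu(w0)                             # scipy returns unit-diagonal l
    s = np.diag(u)
    sign_s, log_s = np.sign(s), np.log(np.abs(s))
    l_diag = np.tril(l, k=-1) + np.eye(n)                     # strictly lower part + unit diagonal
    u_diag = np.triu(u, k=1) + np.diag(sign_s * np.exp(log_s))
    w = p @ l_diag @ u_diag                                   # reconstructed kernel
    return np.allclose(np.log(np.abs(np.linalg.det(w))), log_s.sum())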
def sigma_tf(self, t, X, Y):  # M x 1, M x D, M x 1
    M = self.M
    D = self.D
    return tf.matrix_diag(tf.ones([M, D]))  # M x D x D
def matrix_diag(x):
    return tf.matrix_diag(x)
def mmr_fc(z_list, y_true, W, n_in, n_hs, n_out, n_rb, n_db, gamma_rb, gamma_db, bs, q):
    """
    Batch-wise implementation of the Maximum Margin Regularizer for fully-connected networks.
    Note that it is differentiable, and thus can be directly added to the main objective
    (e.g. the cross-entropy loss).

    z_list:   list with all tensors that correspond to preactivation feature maps
              (in particular, z_list[-1] are logits; see models.MLP for details)
    y_true:   one-hot encoded ground truth labels (bs x n_classes)
    W:        list with all weight matrices
    n_in:     total number of input pixels (e.g. 784 for MNIST)
    n_hs:     list of number of hidden units for every hidden layer (e.g. [1024] for FC1)
    n_rb:     number of closest region boundaries to take
    n_db:     number of closest decision boundaries to take
    gamma_rb: gamma for region boundaries (approx. corresponds to the radius of the Lp-ball
              that we want to certify robustness in)
    gamma_db: gamma for decision boundaries (approx. corresponds to the radius of the Lp-ball
              that we want to be robust in)
    bs:       batch size
    q:        q-norm which is the dual norm to the p-norm that we aim to be robust at
              (e.g. if p=np.inf, q=1)
    """
    eps_num_stabil = 1e-5  # epsilon for numerical stability in the denominator of the distances
    n_hl = len(W) - 1  # number of hidden layers
    y_pred = z_list[-1]

    relus = []
    for i in range(n_hl):
        relu = tf.cast(tf.greater(z_list[i], 0), tf.float32)
        relus.append(tf.expand_dims(relu, 1))

    # region boundaries
    dist_rb = tf.abs(z_list[0]) / tf.norm(W[0], axis=0, ord=q)  # bs x n_hs[0] due to broadcasting

    V = tf.reshape(tf.tile(W[0], [bs, 1]), [bs, n_in, n_hs[0]])  # bs x d x n1
    V = V * relus[0]  # element-wise mult using broadcasting, result: bs x d x n_cur
    for i in range(1, n_hl):
        V = calc_v_fc(V, W[i])  # bs x d x n_hs[i]
        new_dist_rb = tf.abs(z_list[i]) / tf.norm(V, axis=1, ord=q)  # bs x n_hs[i]
        dist_rb = tf.concat([dist_rb, new_dist_rb], 1)  # bs x sum(n_hs[:i+1])
        V = V * relus[i]  # element-wise mult using broadcasting, result: bs x d x n_cur

    th = zero_out_non_min_distances(dist_rb, n_rb)
    rb_term = tf.reduce_sum(th * tf.maximum(0.0, 1.0 - dist_rb / gamma_rb), axis=1)

    # decision boundaries
    V_last = calc_v_fc(V, W[-1])
    y_true_diag = tf.matrix_diag(y_true)
    LLK2 = V_last @ y_true_diag  # bs x d x K @ bs x K x K = bs x d x K
    l = tf.reduce_sum(LLK2, axis=2)  # bs x d
    l = tf.tile(l, [1, n_out])  # bs x d*K
    l = tf.reshape(l, [-1, n_out, n_in])  # bs x K x d
    V_argmax = tf.transpose(l, perm=[0, 2, 1])  # bs x d x K

    diff_v = tf.abs(V_last - V_argmax)
    diff_v = diff_v + eps_num_stabil * tf.cast(tf.less(diff_v, eps_num_stabil), tf.float32)
    dist_db_denominator = tf.norm(diff_v, axis=1, ord=q)

    y_pred_diag = tf.expand_dims(y_pred, 1)
    y_pred_correct = y_pred_diag @ y_true_diag  # bs x 1 x K @ bs x K x K = bs x 1 x K
    y_pred_correct = tf.reduce_sum(y_pred_correct, axis=2)  # bs x 1
    y_pred_correct = tf.tile(y_pred_correct, [1, n_out])  # bs x K
    y_pred_correct = tf.reshape(y_pred_correct, [-1, n_out, 1])  # bs x K x 1
    y_pred_correct = tf.transpose(y_pred_correct, perm=[0, 2, 1])  # bs x 1 x K

    dist_db_numerator = tf.squeeze(y_pred_correct - y_pred_diag, 1)  # bs x K
    dist_db_numerator = dist_db_numerator + 100.0 * y_true  # bs x K
    dist_db = dist_db_numerator / dist_db_denominator + y_true * 2.0 * gamma_db

    th = zero_out_non_min_distances(dist_db, n_db)
    db_term = tf.reduce_sum(th * tf.maximum(0.0, 1.0 - dist_db / gamma_db), axis=1)

    return rb_term, db_term
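# Minimal NumPy illustration (not from the source) of the first-layer region-boundary
# distance used above: the p-norm distance from an input x to the hyperplane
# {x': W0[:, j]^T x' + b0[j] = 0} is |z_j| / ||W0[:, j]||_q with q the dual norm of p
# (e.g. q = 1 for p = inf), which is exactly dist_rb for layer 0.
import numpy as np


def first_layer_region_distances(x, W0, b0=0.0, q=1):
    """x: (bs, d) inputs, W0: (d, n_hidden) first-layer weights -> (bs, n_hidden) distances."""
    z0 = x @ W0 + b0  # pre-activations of the first layer
    return np.abs(z0) / np.linalg.norm(W0, ord=q, axis=0)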
def mmr_cnn(z_list, x, y_true, model, n_rb, n_db, gamma_rb, gamma_db, bs, q):
    """
    Batch-wise implementation of the Maximum Margin Regularizer for CNNs as a TensorFlow
    computational graph. Note that it is differentiable, and thus can be directly added
    to the main objective (e.g. the cross-entropy loss).

    z_list:   list with all tensors that correspond to preactivation feature maps
              (in particular, z_list[-1] are logits; see models.LeNetSmall for details)
    x:        input points (bs x image_height x image_width x image_n_channels)
    y_true:   one-hot encoded ground truth labels (bs x n_classes)
    model:    models.CNN object that contains a model with its weights, strides, padding, etc
    n_rb:     number of closest region boundaries to take
    n_db:     number of closest decision boundaries to take
    gamma_rb: gamma for region boundaries (approx. corresponds to the radius of the Lp-ball
              that we want to certify robustness in)
    gamma_db: gamma for decision boundaries (approx. corresponds to the radius of the Lp-ball
              that we want to be robust in)
    bs:       batch size
    q:        q-norm which is the dual norm to the p-norm that we aim to be robust at
              (e.g. if p=np.inf, q=1)
    """
    eps_num_stabil = 1e-5  # epsilon for numerical stability in the denominator of the distances
    # the padding and strides should be the same as in the forward pass conv
    strides, padding = model.strides, model.padding
    y_pred = z_list[-1]

    z_conv, z_fc, relus_conv, relus_fc, W_conv, W_fc = [], [], [], [], [], []
    for w, y in zip(model.W, z_list):
        # Depending on the shape we form pre-activation values and their relu switches
        if len(y.shape) == 4:  # if conv layer
            z_conv.append(y)
            relu = tf.cast(tf.greater(y, 0), tf.float32)
            relus_conv.append(tf.expand_dims(relu, 1))
            W_conv.append(w)
        else:
            z_fc.append(y)
            relu = tf.cast(tf.greater(y, 0), tf.float32)
            relus_fc.append(tf.expand_dims(relu, 1))
            W_fc.append(w)

    h_in, w_in, c_in = int(x.shape[1]), int(x.shape[2]), int(x.shape[3])
    n_in = h_in * w_in * c_in
    n_out = y_true.shape[1]

    # z[0]: bs x h_next x w_next x n_next, W[0]: h_filter x w_filter x n_prev x n_next
    w_matrix = tf.reshape(W_conv[0], [-1, int(W_conv[0].shape[-1])])  # h_filter*w_filter*n_col x n_next
    denom = tf.norm(w_matrix, axis=0, ord=q, keep_dims=True)  # n_next
    dist_rb = tf.abs(z_conv[0]) / denom  # bs x h_next x w_next x n_next
    dist_rb = tf.reshape(dist_rb, [bs, int(z_conv[0].shape[1] * z_conv[0].shape[2] * z_conv[0].shape[3])])  # bs x h_next*w_next*n_next

    # We need to get the conv matrix. Instead of using loops to construct such a matrix,
    # we can apply the W[0] conv filter to a reshaped identity matrix, and then duplicate
    # the resulting tensor bs times.
    identity_input_fm = tf.reshape(tf.eye(n_in, n_in), [1, n_in, h_in, w_in, c_in])
    V = calc_v_conv(identity_input_fm, W_conv[0], strides[0], padding)  # 1 x d x h_next x w_next x c_next
    V = tf.tile(V, [bs, 1, 1, 1, 1])  # bs x d x h_next x w_next x c_next
    V = V * relus_conv[0]
    for i in range(1, len(z_conv)):
        V = calc_v_conv(V, W_conv[i], strides[i], padding)  # bs x d x h_next x w_next x c_next
        V_stable = V + eps_num_stabil * tf.cast(tf.less(tf.abs(V), eps_num_stabil), tf.float32)  # note: +eps would also work
        new_dist_rb = tf.abs(z_conv[i]) / tf.norm(V_stable, axis=1, ord=q)  # bs x h_next x w_next x c_next
        new_dist_rb = tf.reshape(new_dist_rb, [bs, z_conv[i].shape[1] * z_conv[i].shape[2] * z_conv[i].shape[3]])  # bs x h_next*w_next*c_next
        dist_rb = tf.concat([dist_rb, new_dist_rb], 1)  # bs x sum(n_neurons[:i+1])
        V = V * relus_conv[i]  # element-wise mult using broadcasting, result: bs x d x h_cur x w_cur x c_cur

    # Flattening after the last conv layer
    V = tf.reshape(V, [bs, n_in, V.shape[2] * V.shape[3] * V.shape[4]])  # bs x d x h_prev*w_prev*c_prev

    for i in range(len(z_fc) - 1):  # the last layer requires special handling
        V = calc_v_fc(V, W_fc[i])  # bs x d x n_hs[i]
        V_stable = V + eps_num_stabil * tf.cast(tf.less(tf.abs(V), eps_num_stabil), tf.float32)
        new_dist_rb = tf.abs(z_fc[i]) / tf.norm(V_stable, axis=1, ord=q)  # bs x n_hs[i]
        dist_rb = tf.concat([dist_rb, new_dist_rb], 1)  # bs x sum(n_hs[:i+1])
        V = V * relus_fc[i]  # element-wise mult using broadcasting, result: bs x d x n_cur

    th = zero_out_non_min_distances(dist_rb, n_rb)
    rb_term = tf.reduce_sum(th * tf.maximum(0.0, 1.0 - dist_rb / gamma_rb), axis=1)

    # decision boundaries
    V = calc_v_fc(V, W_fc[-1])
    y_true_diag = tf.matrix_diag(y_true)
    LLK2 = V @ y_true_diag  # bs x d x K @ bs x K x K = bs x d x K
    l = tf.reduce_sum(LLK2, axis=2)  # bs x d
    l = tf.tile(l, [1, n_out])  # bs x d*K
    l = tf.reshape(l, [-1, n_out, n_in])  # bs x K x d
    V_argmax = tf.transpose(l, perm=[0, 2, 1])  # bs x d x K

    diff_v = tf.abs(V - V_argmax)
    diff_v = diff_v + eps_num_stabil * tf.cast(tf.less(diff_v, eps_num_stabil), tf.float32)
    dist_db_denominator = tf.norm(diff_v, axis=1, ord=q)

    y_pred_diag = tf.expand_dims(y_pred, 1)
    y_pred_correct = y_pred_diag @ y_true_diag  # bs x 1 x K @ bs x K x K = bs x 1 x K
    y_pred_correct = tf.reduce_sum(y_pred_correct, axis=2)  # bs x 1
    y_pred_correct = tf.tile(y_pred_correct, [1, n_out])  # bs x K
    y_pred_correct = tf.reshape(y_pred_correct, [-1, n_out, 1])  # bs x K x 1
    y_pred_correct = tf.transpose(y_pred_correct, perm=[0, 2, 1])  # bs x 1 x K

    dist_db_numerator = tf.squeeze(y_pred_correct - y_pred_diag, 1)  # bs x K
    dist_db_numerator = dist_db_numerator + 100.0 * y_true  # bs x K
    dist_db = dist_db_numerator / dist_db_denominator + y_true * 2.0 * gamma_db

    th = zero_out_non_min_distances(dist_db, n_db)
    db_term = tf.reduce_sum(th * tf.maximum(0.0, 1.0 - dist_db / gamma_db), axis=1)

    return rb_term, db_term