def applyStatsEigen(self, eigen_list):
    """Build the ops that store new eigen values/vectors in their variables.

    ``eigen_list`` is walked in lockstep with ``self.eigen_update_list``;
    each entry of the latter is used as a key into
    ``self.eigen_reverse_lookup`` to find the variable that receives the
    matching tensor from ``eigen_list``.

    Args:
        eigen_list: tensors/values to assign, ordered like
            ``self.eigen_update_list``.

    Returns:
        A list holding one assign op per pair, followed by the op that
        increments ``self.factor_step`` (created under a control dependency
        on all the assigns) and, when ``KFAC_DEBUG`` is set, a ``tf.Print``
        op.
    """
    print(('updating %d eigenvalue/vectors' % len(eigen_list)))

    # One locked assignment per (new value, lookup key) pair.
    updateOps = [
        tf.assign(self.eigen_reverse_lookup[marker], value, use_locking=True)
        for value, marker in zip(eigen_list, self.eigen_update_list)
    ]

    # The step counter is only bumped once every assignment has run.
    with tf.control_dependencies(updateOps):
        updateOps.append(tf.assign_add(self.factor_step, 1))
        if KFAC_DEBUG:
            zero = tf.constant(0.)
            message = tf.convert_to_tensor('updated kfac factors')
            updateOps.append(tf.Print(zero, [message]))
    return updateOps
def coldSGDstart():
    """Build the op for one plain (non-K-FAC) gradient step.

    Reads ``self``, ``grads`` and ``coldOptim`` from the enclosing scope
    (not visible in this block; ``grads`` is unpacked as an iterable of
    ``(gradient, variable)`` pairs).

    Returns:
        A ``tf.group`` of the op incrementing ``self.sgd_step`` and the op
        returned by ``coldOptim.apply_gradients``.
    """
    sgd_grads, sgd_var = zip(*grads)

    # Identity test instead of ``!= None``: the comparison must not be
    # routed through an overloaded ``__ne__`` on whatever type holds the
    # clipping threshold.
    if self.max_grad_norm is not None:
        # The global norm returned alongside the clipped grads is not needed.
        sgd_grads, _ = tf.clip_by_global_norm(sgd_grads, self.max_grad_norm)

    sgd_grads = list(zip(sgd_grads, sgd_var))

    sgd_step_op = tf.assign_add(self.sgd_step, 1)
    coldOptim_op = coldOptim.apply_gradients(sgd_grads)
    if KFAC_DEBUG:
        # Only print after both the counter update and the optimizer step.
        with tf.control_dependencies([sgd_step_op, coldOptim_op]):
            sgd_step_op = tf.Print(sgd_step_op, [
                self.sgd_step,
                tf.convert_to_tensor('doing cold sgd step'),
            ])
    return tf.group(sgd_step_op, coldOptim_op)
def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
    """Build the ops that fold new statistics into the stats variables.

    Args:
        statsUpdates: mapping from each stats variable to the new statistic
            tensor computed for it.
        accumulate: when true, add ``accumulateCoeff * stats_new`` to the
            variable; otherwise apply an exponential running average using
            ``self._stats_decay``.
        accumulateCoeff: weight of each new statistic in accumulate mode.

    Returns:
        A one-element list containing the op that increments
        ``self.stats_step``; that op is created under a control dependency
        on every per-variable update op.
    """
    updateOps = []
    for stats_var, stats_new in statsUpdates.items():
        if accumulate:
            # simple superbatch averaging
            update_op = tf.assign_add(
                stats_var, accumulateCoeff * stats_new, use_locking=True)
        else:
            # exponential running averaging: decay the old value first, then
            # add the weighted new statistic on top of the decayed result.
            decayed = tf.assign(
                stats_var, stats_var * self._stats_decay, use_locking=True)
            update_op = tf.assign_add(
                decayed, (1. - self._stats_decay) * stats_new,
                use_locking=True)
        updateOps.append(update_op)

    with tf.control_dependencies(updateOps):
        stats_step_op = tf.assign_add(self.stats_step, 1)

    if KFAC_DEBUG:
        debug_items = [
            tf.convert_to_tensor('step:'),
            self.global_step,
            tf.convert_to_tensor('fac step:'),
            self.factor_step,
            tf.convert_to_tensor('sgd step:'),
            self.sgd_step,
            tf.convert_to_tensor('Accum:'),
            tf.convert_to_tensor(accumulate),
            tf.convert_to_tensor('Accum coeff:'),
            tf.convert_to_tensor(accumulateCoeff),
            tf.convert_to_tensor('stat step:'),
            self.stats_step,
        ]
        # Show at most the first two update ops. Slicing (rather than
        # indexing [0] and [1]) keeps debug mode working when fewer than two
        # stats variables are being updated.
        stats_step_op = tf.Print(stats_step_op, debug_items + updateOps[:2])
    return [stats_step_op]
def getKfacPrecondUpdates(self, gradlist, varlist):
    """Precondition gradients with the stored eigen factors and rescale them.

    For every variable that has at least one forward or backward factor
    registered in ``self.stats``, the gradient is reshaped, multiplied by
    the stored ``Q`` matrices through ``gmatmul``, divided by a damped
    product of the stored ``e`` values, multiplied by the ``Q`` matrices
    again with the opposite transpose flags, and restored to its original
    shape. All resulting updates are then multiplied by
    ``min(1, sqrt(self._clip_kl / vg))`` where ``vg`` is the sum over
    variables of ``reduce_sum(new_grad * old_grad * self._lr ** 2)``.

    Args:
        gradlist: gradients, parallel to ``varlist``.
        varlist: variables; used as keys into ``self.stats`` and
            ``self._weight_decay_dict``.

    Returns:
        A list of scaled update tensors in ``varlist`` order. They are
        created under a control dependency on assigning ``vg`` to
        ``self.vFv``.
    """
    updatelist = []  # placeholder; rebound to the real list at the end
    vg = 0.

    assert len(self.stats) > 0
    assert len(self.stats_eigen) > 0
    assert len(self.factors) > 0
    counter = 0  # number of gradients that were actually projected

    # Working copy of the gradients keyed by variable; entries are replaced
    # by their preconditioned versions as the loop proceeds.
    grad_dict = {var: grad for grad, var in zip(gradlist, varlist)}

    for grad, var in zip(gradlist, varlist):
        GRAD_RESHAPE = False
        GRAD_TRANSPOSE = False  # never read in this function

        fpropFactoredFishers = self.stats[var]['fprop_concat_stats']
        bpropFactoredFishers = self.stats[var]['bprop_concat_stats']

        # Variables without any registered factor keep their raw gradient.
        if (len(fpropFactoredFishers) + len(bpropFactoredFishers)) > 0:
            counter += 1
            GRAD_SHAPE = grad.get_shape()  # remembered to undo the reshape
            if len(grad.get_shape()) > 2:
                # reshape conv kernel parameters
                KW = int(grad.get_shape()[0])
                KH = int(grad.get_shape()[1])
                C = int(grad.get_shape()[2])
                D = int(grad.get_shape()[3])

                if len(fpropFactoredFishers) > 1 and self._channel_fac:
                    # reshape conv kernel parameters into tensor
                    grad = tf.reshape(grad, [KW * KH, C, D])
                else:
                    # reshape conv kernel parameters into 2D grad
                    grad = tf.reshape(grad, [-1, D])
                GRAD_RESHAPE = True
            elif len(grad.get_shape()) == 1:
                # reshape bias or 1D parameters
                D = int(grad.get_shape()[0])

                grad = tf.expand_dims(grad, 0)
                GRAD_RESHAPE = True
            else:
                # 2D parameters
                C = int(grad.get_shape()[0])
                D = int(grad.get_shape()[1])

            if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                # use homogeneous coordinates only works for 2D grad.
                # TO-DO: figure out how to factorize bias grad
                # stack bias grad: the associated bias gradient is appended
                # as an extra row along axis 0.
                var_assnBias = self.stats[var]['assnBias']
                grad = tf.concat(
                    [grad, tf.expand_dims(grad_dict[var_assnBias], 0)], 0)

            # project gradient to eigen space and reshape the eigenvalues
            # for broadcasting
            eigVals = []

            for idx, stats in enumerate(
                    self.stats[var]['fprop_concat_stats']):
                Q = self.stats_eigen[stats]['Q']
                e = detectMinVal(self.stats_eigen[stats]['e'],
                                 var,
                                 name='act',
                                 debug=KFAC_DEBUG)

                Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='act')
                eigVals.append(e)
                grad = gmatmul(Q, grad, transpose_a=True, reduce_dim=idx)

            for idx, stats in enumerate(
                    self.stats[var]['bprop_concat_stats']):
                Q = self.stats_eigen[stats]['Q']
                e = detectMinVal(self.stats_eigen[stats]['e'],
                                 var,
                                 name='grad',
                                 debug=KFAC_DEBUG)

                Q, e = factorReshape(Q, e, grad, facIndx=idx, ftype='grad')
                eigVals.append(e)
                grad = gmatmul(grad, Q, transpose_b=False, reduce_dim=idx)

            # whiten using eigenvalues
            weightDecayCoeff = 0.
            if var in self._weight_decay_dict:
                weightDecayCoeff = self._weight_decay_dict[var]
                if KFAC_DEBUG:
                    print(('weight decay coeff for %s is %f' %
                           (var.name, weightDecayCoeff)))

            if self._factored_damping:
                if KFAC_DEBUG:
                    print(('use factored damping for %s' % (var.name)))
                coeffs = 1.
                num_factors = len(eigVals)
                # compute the ratio of two trace norm of the left and right
                # KFac matrices, and their generalization
                if len(eigVals) == 1:
                    damping = self._epsilon + weightDecayCoeff
                else:
                    damping = tf.pow(self._epsilon + weightDecayCoeff,
                                     1. / num_factors)
                # Mean absolute value of each factor's eigenvalue tensor.
                eigVals_tnorm_avg = [
                    tf.reduce_mean(tf.abs(e)) for e in eigVals
                ]
                for e, e_tnorm in zip(eigVals, eigVals_tnorm_avg):
                    # Averages belonging to the *other* factors.
                    # NOTE(review): this filter assumes ``!=`` between these
                    # tensors yields a plain Python bool (object comparison)
                    # rather than an element-wise tensor - confirm for the
                    # TensorFlow version in use.
                    eig_tnorm_negList = [
                        item for item in eigVals_tnorm_avg
                        if item != e_tnorm
                    ]
                    if len(eigVals) == 1:
                        adjustment = 1.
                    elif len(eigVals) == 2:
                        adjustment = tf.sqrt(e_tnorm / eig_tnorm_negList[0])
                    else:
                        eig_tnorm_negList_prod = reduce(
                            lambda x, y: x * y, eig_tnorm_negList)
                        adjustment = tf.pow(
                            tf.pow(e_tnorm, num_factors - 1.) /
                            eig_tnorm_negList_prod, 1. / num_factors)
                    coeffs *= (e + adjustment * damping)
            else:
                # Unfactored damping: product of all eigenvalue tensors plus
                # a single additive damping term.
                coeffs = 1.
                damping = (self._epsilon + weightDecayCoeff)
                for e in eigVals:
                    coeffs *= e
                coeffs += damping

            #grad = tf.Print(grad, [tf.convert_to_tensor('1'), tf.convert_to_tensor(var.name), grad.get_shape()])

            grad /= coeffs

            #grad = tf.Print(grad, [tf.convert_to_tensor('2'), tf.convert_to_tensor(var.name), grad.get_shape()])

            # project gradient back to euclidean space (same Q matrices,
            # transpose flags flipped relative to the forward projection)
            for idx, stats in enumerate(
                    self.stats[var]['fprop_concat_stats']):
                Q = self.stats_eigen[stats]['Q']
                grad = gmatmul(Q, grad, transpose_a=False, reduce_dim=idx)

            for idx, stats in enumerate(
                    self.stats[var]['bprop_concat_stats']):
                Q = self.stats_eigen[stats]['Q']
                grad = gmatmul(grad, Q, transpose_b=True, reduce_dim=idx)

            #grad = tf.Print(grad, [tf.convert_to_tensor('3'), tf.convert_to_tensor(var.name), grad.get_shape()])
            if (self.stats[var]['assnBias'] is not None) and not self._blockdiag_bias:
                # use homogeneous coordinates only works for 2D grad.
                # TO-DO: figure out how to factorize bias grad
                # un-stack bias grad: the last row goes back to the bias
                # variable, the remaining rows stay with this variable.
                var_assnBias = self.stats[var]['assnBias']
                C_plus_one = int(grad.get_shape()[0])
                grad_assnBias = tf.reshape(
                    tf.slice(grad, begin=[C_plus_one - 1, 0], size=[1, -1]),
                    var_assnBias.get_shape())
                grad_assnWeights = tf.slice(grad,
                                            begin=[0, 0],
                                            size=[C_plus_one - 1, -1])
                grad_dict[var_assnBias] = grad_assnBias
                grad = grad_assnWeights

            #grad = tf.Print(grad, [tf.convert_to_tensor('4'), tf.convert_to_tensor(var.name), grad.get_shape()])
            if GRAD_RESHAPE:
                grad = tf.reshape(grad, GRAD_SHAPE)

            grad_dict[var] = grad

    print(('projecting %d gradient matrices' % counter))

    for g, var in zip(gradlist, varlist):
        grad = grad_dict[var]
        ### clipping ###
        if KFAC_DEBUG:
            print(('apply clipping to %s' % (var.name)))
            # NOTE(review): the result of this tf.Print is discarded, so the
            # print op is not connected to anything built below.
            tf.Print(grad, [tf.sqrt(tf.reduce_sum(tf.pow(grad, 2)))],
                     "Euclidean norm of new grad")
        # Accumulate sum(new_grad * raw_grad) scaled by lr^2.
        local_vg = tf.reduce_sum(grad * g * (self._lr * self._lr))
        vg += local_vg

    # rescale everything
    if KFAC_DEBUG:
        print('apply vFv clipping')

    scaling = tf.minimum(1., tf.sqrt(self._clip_kl / vg))
    if KFAC_DEBUG:
        scaling = tf.Print(scaling, [
            tf.convert_to_tensor('clip: '), scaling,
            tf.convert_to_tensor(' vFv: '), vg
        ])
    # The scaled updates are created under a dependency on storing vg in
    # self.vFv.
    with tf.control_dependencies([tf.assign(self.vFv, vg)]):
        updatelist = [grad_dict[var] for var in varlist]
        for i, item in enumerate(updatelist):
            updatelist[i] = scaling * item

    return updatelist
def computeStatsEigen(self):
    """Build the eigen-decomposition ops for every key of ``self.stats_eigen``.

    Runs ``tf.self_adjoint_eig`` on each key under a ``/cpu:0`` device scope
    and records, in ``self.eigen_reverse_lookup``, which stored ``'e'`` /
    ``'Q'`` entry each resulting tensor belongs to. The ordered list of
    result tensors is kept in ``self.eigen_update_list``.

    Returns:
        The list of eigenvalue/eigenvector tensors (``e`` then ``Q`` for
        each key); when ``KFAC_DEBUG`` is set a ``tf.Print`` op is appended
        to the returned list but not to ``self.eigen_update_list``.
    """
    # TO-DO: figure out why this op has delays (possibly moving
    # eigenvectors around?)
    with tf.device('/cpu:0'):
        # NOTE(review): the two helpers below are only referenced from the
        # commented-out lines further down, so the decomposition currently
        # runs on the stats keys directly and ``_use_float64`` only affects
        # the float32 casts of the results.
        def removeNone(tensor_list):
            return [item for item in tensor_list if item is not None]

        def copyStats(var_list):
            print("copying stats to buffer tensors before eigen decomp")
            buffers = {}
            copies = []
            for item in var_list:
                if item is None:
                    copies.append(None)
                    continue
                if item not in buffers:
                    snapshot = tf.identity(item)
                    if self._use_float64:
                        snapshot = tf.cast(snapshot, tf.float64)
                    buffers[item] = snapshot
                copies.append(buffers[item])
            return copies

        #stats = [copyStats(self.fStats), copyStats(self.bStats)]
        #stats = [self.fStats, self.bStats]

        eigen_vars = self.stats_eigen
        computedEigen = {}
        reverse_lookup = {}
        updateOps = []
        # sync copied stats
        # with tf.control_dependencies(removeNone(stats[0]) +
        # removeNone(stats[1])):
        with tf.control_dependencies([]):
            for stats_var in eigen_vars:
                if stats_var in computedEigen:
                    continue
                e, Q = tf.self_adjoint_eig(stats_var)
                if self._use_float64:
                    e = tf.cast(e, tf.float32)
                    Q = tf.cast(Q, tf.float32)
                updateOps.extend([e, Q])
                computedEigen[stats_var] = {'e': e, 'Q': Q}
                # Map each fresh tensor to the stored entry it will update.
                targets = eigen_vars[stats_var]
                reverse_lookup[e] = targets['e']
                reverse_lookup[Q] = targets['Q']

        self.eigen_reverse_lookup = reverse_lookup
        self.eigen_update_list = updateOps

        if KFAC_DEBUG:
            # Keep a copy without the debug print op appended below.
            self.eigen_update_list = list(updateOps)
            with tf.control_dependencies(updateOps):
                zero = tf.constant(0.)
                message = tf.convert_to_tensor('computed factor eigen')
                updateOps.append(tf.Print(zero, [message]))

    return updateOps
def compute_stats(self, loss_sampled, var_list=None):
    """Build the per-batch statistics tensors for every stats variable.

    Differentiates ``loss_sampled`` with respect to the variables, obtains
    per-variable factor and stats descriptions from ``self.getFactors`` and
    ``self.getStats``, and for each forward/backward stats variable builds
    ``matmul(factor^T, factor)`` divided by a batch-size term.

    Args:
        loss_sampled: loss tensor passed to ``tf.gradients``.
        var_list: variables to differentiate against; defaults to
            ``tf.trainable_variables()`` when ``None``.

    Returns:
        A dict mapping each stats variable to its newly built statistic
        tensor. The same dict is stored in ``self.statsUpdates`` and the
        gradients in ``self.gs``.
    """
    varlist = var_list
    if varlist is None:
        varlist = tf.trainable_variables()

    gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
    self.gs = gs
    factors = self.getFactors(gs, varlist)
    stats = self.getStats(factors, varlist)

    updateOps = []  # appended to below but never read or returned here
    statsUpdates = {}
    # Stats variables that already have an update this call and must not be
    # recomputed when they appear again under another variable.
    statsUpdates_cache = {}
    for var in varlist:
        opType = factors[var]['opName']
        fops = factors[var]['op']
        fpropFactor = factors[var]['fpropFactors_concat']
        fpropStats_vars = stats[var]['fprop_concat_stats']
        bpropFactor = factors[var]['bpropFactors_concat']
        bpropStats_vars = stats[var]['bprop_concat_stats']
        # Rank-1 SVD pieces for this variable, keyed by their last dimension.
        SVD_factors = {}
        for stats_var in fpropStats_vars:
            stats_var_dim = int(stats_var.get_shape()[0])
            if stats_var not in statsUpdates_cache:
                old_fpropFactor = fpropFactor  # never read afterwards
                B = (tf.shape(fpropFactor)[0])  # batch size
                if opType == 'Conv2D':
                    strides = fops.get_attr("strides")
                    padding = fops.get_attr("padding")
                    convkernel_size = var.get_shape()[0:3]

                    KH = int(convkernel_size[0])
                    KW = int(convkernel_size[1])
                    C = int(convkernel_size[2])
                    flatten_size = int(KH * KW * C)

                    # Output spatial size, read from the backward factor.
                    Oh = int(bpropFactor.get_shape()[1])
                    Ow = int(bpropFactor.get_shape()[2])

                    if Oh == 1 and Ow == 1 and self._channel_fac:
                        # factorization along the channels
                        # assume independence among input channels
                        # factor = B x 1 x 1 x (KH xKW x C)
                        # patches = B x Oh x Ow x (KH xKW x C)
                        if len(SVD_factors) == 0:
                            if KFAC_DEBUG:
                                print((
                                    'approx %s act factor with rank-1 SVD factors'
                                    % (var.name)))
                            # find closest rank-1 approx to the feature map
                            S, U, V = tf.batch_svd(
                                tf.reshape(fpropFactor, [-1, KH * KW, C]))
                            # get rank-1 approx slides
                            sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                            patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                            full_factor_shape = fpropFactor.get_shape()
                            patches_k.set_shape(
                                [full_factor_shape[0], KH * KW])
                            patches_c = V[:, :, 0] * sqrtS1  # B x C
                            patches_c.set_shape([full_factor_shape[0], C])
                            SVD_factors[C] = patches_c
                            SVD_factors[KH * KW] = patches_k
                        # Pick the piece whose size matches this stats var.
                        fpropFactor = SVD_factors[stats_var_dim]

                    else:
                        # poor mem usage implementation
                        patches = tf.extract_image_patches(
                            fpropFactor,
                            ksizes=[
                                1, convkernel_size[0], convkernel_size[1], 1
                            ],
                            strides=strides,
                            rates=[1, 1, 1, 1],
                            padding=padding)

                        if self._approxT2:
                            if KFAC_DEBUG:
                                print(('approxT2 act fisher for %s' %
                                       (var.name)))
                            # T^2 terms * 1/T^2, size: B x C
                            fpropFactor = tf.reduce_mean(patches, [1, 2])
                        else:
                            # size: (B x Oh x Ow) x C
                            fpropFactor = tf.reshape(
                                patches, [-1, flatten_size]) / Oh / Ow
                fpropFactor_size = int(fpropFactor.get_shape()[-1])
                # A stats var exactly one larger than the factor gets an
                # extra constant column appended to the factor.
                if stats_var_dim == (fpropFactor_size +
                                     1) and not self._blockdiag_bias:
                    if opType == 'Conv2D' and not self._approxT2:
                        # correct padding for numerical stability (we
                        # divided out OhxOw from activations for T1 approx)
                        fpropFactor = tf.concat([
                            fpropFactor,
                            tf.ones([tf.shape(fpropFactor)[0], 1]) / Oh / Ow
                        ], 1)
                    else:
                        # use homogeneous coordinates
                        fpropFactor = tf.concat([
                            fpropFactor,
                            tf.ones([tf.shape(fpropFactor)[0], 1])
                        ], 1)

                # average over the number of data points in a batch
                # divided by B
                cov = tf.matmul(fpropFactor, fpropFactor,
                                transpose_a=True) / tf.cast(B, tf.float32)
                updateOps.append(cov)
                statsUpdates[stats_var] = cov
                if opType != 'Conv2D':
                    # HACK: for convolution we recompute fprop stats for
                    # every layer including forking layers
                    statsUpdates_cache[stats_var] = cov

        for stats_var in bpropStats_vars:
            stats_var_dim = int(stats_var.get_shape()[0])
            if stats_var not in statsUpdates_cache:
                old_bpropFactor = bpropFactor  # never read afterwards
                bpropFactor_shape = bpropFactor.get_shape()
                B = tf.shape(bpropFactor)[0]  # batch size
                C = int(bpropFactor_shape[-1])  # num channels
                if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                    if fpropFactor is not None:
                        if self._approxT2:
                            if KFAC_DEBUG:
                                print(('approxT2 grad fisher for %s' %
                                       (var.name)))
                            bpropFactor = tf.reduce_sum(
                                bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                        else:
                            # NOTE(review): Oh and Ow are only bound in the
                            # Conv2D branch of the forward loop above; this
                            # line relies on that branch having run.
                            bpropFactor = tf.reshape(
                                bpropFactor,
                                [-1, C]) * Oh * Ow  # T * 1/T terms
                    else:
                        # just doing block diag approx. spatial independent
                        # structure does not apply here. summing over
                        # spatial locations
                        if KFAC_DEBUG:
                            print(('block diag approx fisher for %s' %
                                   (var.name)))
                        bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                # assume sampled loss is averaged. TO-DO:figure out better
                # way to handle this
                bpropFactor *= tf.to_float(B)

                cov_b = tf.matmul(bpropFactor, bpropFactor,
                                  transpose_a=True) / tf.to_float(
                                      tf.shape(bpropFactor)[0])

                updateOps.append(cov_b)
                statsUpdates[stats_var] = cov_b
                statsUpdates_cache[stats_var] = cov_b

    if KFAC_DEBUG:
        # Attach a print op to the first entry of the result dict.
        aKey = list(statsUpdates.keys())[0]
        statsUpdates[aKey] = tf.Print(statsUpdates[aKey], [
            tf.convert_to_tensor('step:'),
            self.global_step,
            tf.convert_to_tensor('computing stats'),
        ])
    self.statsUpdates = statsUpdates
    return statsUpdates
def detectMinVal(input_mat, var, threshold=1e-6, name='', debug=False):
    """Return ``clipoutNeg(input_mat, threshold)``, optionally with a debug print.

    Args:
        input_mat: tensor whose minimum and maximum are inspected.
        var: object whose ``name`` attribute is included in the debug
            message.
        threshold: forwarded unchanged to ``clipoutNeg``.
        name: label inserted into the debug message.
        debug: when true, wrap the result in a ``tf.cond`` that prints the
            min, max and max/min ratio unless the ratio is greater than 0
            or less than -500.

    Returns:
        The clipped tensor, or the ``tf.cond`` wrapping it when ``debug``
        is set.
    """
    eigen_min = tf.reduce_min(input_mat)
    eigen_max = tf.reduce_max(input_mat)
    eigen_ratio = eigen_max / eigen_min
    clipped = clipoutNeg(input_mat, threshold)

    if not debug:
        return clipped

    def _passthrough():
        return clipped

    def _report():
        # Same tensor, routed through a print of the offending statistics.
        return tf.Print(clipped, [
            tf.convert_to_tensor('screwed ratio ' + name + ' eigen values!!!'),
            tf.convert_to_tensor(var.name),
            eigen_min,
            eigen_max,
            eigen_ratio,
        ])

    ratio_is_quiet = tf.logical_or(
        tf.greater(eigen_ratio, 0.), tf.less(eigen_ratio, -500))
    return tf.cond(ratio_is_quiet, _passthrough, _report)