import tensorflow as tf

KFAC_DEBUG = False  # module-level verbosity flag referenced by the K-FAC helpers below


def reshape_for_broadcasting(source, target):
    """Reshape a tensor (source) to the correct shape and dtype of the target
    before broadcasting it with MPI."""
    dim = len(target.get_shape())
    shape = ([1] * (dim - 1)) + [-1]
    return tf.reshape(tf.cast(source, target.dtype), shape)
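# Hedged usage sketch (not from the original module; the tensor names below are
# illustrative assumptions): reshape_for_broadcasting pads a 1-D per-feature
# statistic with leading singleton dims so it broadcasts against the last axis
# of a higher-rank tensor.
def _example_reshape_for_broadcasting():
    batch = tf.zeros([32, 4, 10])            # e.g. a batch of stacked observations
    mean = tf.zeros([10], dtype=tf.float64)  # per-feature statistic, mismatched dtype
    # reshaped has shape [1, 1, 10] and dtype float32, so the subtraction broadcasts
    reshaped = reshape_for_broadcasting(mean, batch)
    return batch - reshaped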
def computeStatsEigen(self):
    """Compute the eigendecompositions using copied stats variables, so that
    concurrent reads/writes from the stats-update queue are avoided."""
    # TODO: figure out why this op has delays (possibly moving
    # eigenvectors around?)
    with tf.device('/cpu:0'):
        def removeNone(tensor_list):
            # drop None placeholders, keeping only real tensors
            local_list = []
            for item in tensor_list:
                if item is not None:
                    local_list.append(item)
            return local_list

        def copyStats(var_list):
            # snapshot each stats variable into a buffer tensor (optionally in
            # float64) so the eigendecomposition reads a consistent copy
            print("copying stats to buffer tensors before eigen decomp")
            redundant_stats = {}
            copied_list = []
            for item in var_list:
                if item is not None:
                    if item not in redundant_stats:
                        if self._use_float64:
                            redundant_stats[item] = tf.cast(
                                tf.identity(item), tf.float64)
                        else:
                            redundant_stats[item] = tf.identity(item)
                    copied_list.append(redundant_stats[item])
                else:
                    copied_list.append(None)
            return copied_list

        # stats = [copyStats(self.fStats), copyStats(self.bStats)]
        # stats = [self.fStats, self.bStats]
        stats_eigen = self.stats_eigen
        computedEigen = {}
        eigen_reverse_lookup = {}
        updateOps = []
        # sync copied stats
        # with tf.control_dependencies(removeNone(stats[0]) +
        # removeNone(stats[1])):
        with tf.control_dependencies([]):
            for stats_var in stats_eigen:
                if stats_var not in computedEigen:
                    eigens = tf.self_adjoint_eig(stats_var)
                    e = eigens[0]
                    Q = eigens[1]
                    if self._use_float64:
                        e = tf.cast(e, tf.float32)
                        Q = tf.cast(Q, tf.float32)
                    updateOps.append(e)
                    updateOps.append(Q)
                    computedEigen[stats_var] = {'e': e, 'Q': Q}
                    eigen_reverse_lookup[e] = stats_eigen[stats_var]['e']
                    eigen_reverse_lookup[Q] = stats_eigen[stats_var]['Q']

        self.eigen_reverse_lookup = eigen_reverse_lookup
        self.eigen_update_list = updateOps
        if KFAC_DEBUG:
            self.eigen_update_list = [item for item in updateOps]
            with tf.control_dependencies(updateOps):
                updateOps.append(
                    tf.Print(tf.constant(0.),
                             [tf.convert_to_tensor('computed factor eigen')]))
    return updateOps
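# Hedged sketch (assumes a constructed KfacOptimizer-style instance `optim` with
# stats_eigen already populated, plus a TF1 session): the op list returned by
# computeStatsEigen is built once and then evaluated whenever the cached factor
# eigendecompositions should be refreshed; eigen_reverse_lookup maps each
# computed tensor back to the variable that should receive it.
def _example_refresh_eigen_decomp(sess, optim):
    eigen_ops = optim.computeStatsEigen()  # build once; e/Q tensors pinned to the CPU
    return sess.run(eigen_ops)             # evaluate the decompositions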
def compute_stats(self, loss_sampled, var_list=None):
    """Build the ops that estimate the K-FAC second-moment statistics
    (activation and gradient covariances) from the sampled loss. Returns a
    dict mapping each stats variable to its covariance update tensor."""
    varlist = var_list
    if varlist is None:
        varlist = tf.trainable_variables()

    gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
    self.gs = gs
    factors = self.getFactors(gs, varlist)
    stats = self.getStats(factors, varlist)

    updateOps = []
    statsUpdates = {}
    statsUpdates_cache = {}
    for var in varlist:
        opType = factors[var]['opName']
        fops = factors[var]['op']
        fpropFactor = factors[var]['fpropFactors_concat']
        fpropStats_vars = stats[var]['fprop_concat_stats']
        bpropFactor = factors[var]['bpropFactors_concat']
        bpropStats_vars = stats[var]['bprop_concat_stats']
        SVD_factors = {}
        for stats_var in fpropStats_vars:
            stats_var_dim = int(stats_var.get_shape()[0])
            if stats_var not in statsUpdates_cache:
                B = tf.shape(fpropFactor)[0]  # batch size
                if opType == 'Conv2D':
                    strides = fops.get_attr("strides")
                    padding = fops.get_attr("padding")
                    convkernel_size = var.get_shape()[0:3]

                    KH = int(convkernel_size[0])
                    KW = int(convkernel_size[1])
                    C = int(convkernel_size[2])
                    flatten_size = int(KH * KW * C)

                    Oh = int(bpropFactor.get_shape()[1])
                    Ow = int(bpropFactor.get_shape()[2])

                    if Oh == 1 and Ow == 1 and self._channel_fac:
                        # factorization along the channels:
                        # assume independence among input channels
                        # factor = B x 1 x 1 x (KH x KW x C)
                        # patches = B x Oh x Ow x (KH x KW x C)
                        if len(SVD_factors) == 0:
                            if KFAC_DEBUG:
                                print(('approx %s act factor with rank-1 SVD factors'
                                       % (var.name)))
                            # find closest rank-1 approx to the feature map
                            S, U, V = tf.svd(
                                tf.reshape(fpropFactor, [-1, KH * KW, C]))
                            # split the sqrt of the top singular value
                            # between the two rank-1 factors
                            sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0]), 1)
                            patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                            full_factor_shape = fpropFactor.get_shape()
                            patches_k.set_shape(
                                [full_factor_shape[0], KH * KW])
                            patches_c = V[:, :, 0] * sqrtS1  # B x C
                            patches_c.set_shape([full_factor_shape[0], C])
                            SVD_factors[C] = patches_c
                            SVD_factors[KH * KW] = patches_k
                        fpropFactor = SVD_factors[stats_var_dim]
                    else:
                        # poor memory usage implementation
                        patches = tf.extract_image_patches(
                            fpropFactor,
                            ksizes=[1, convkernel_size[0],
                                    convkernel_size[1], 1],
                            strides=strides,
                            rates=[1, 1, 1, 1],
                            padding=padding)

                        if self._approxT2:
                            if KFAC_DEBUG:
                                print(('approxT2 act fisher for %s'
                                       % (var.name)))
                            # T^2 terms * 1/T^2, size: B x C
                            fpropFactor = tf.reduce_mean(patches, [1, 2])
                        else:
                            # size: (B x Oh x Ow) x C
                            fpropFactor = tf.reshape(
                                patches, [-1, flatten_size]) / Oh / Ow

                fpropFactor_size = int(fpropFactor.get_shape()[-1])
                if stats_var_dim == (fpropFactor_size + 1) \
                        and not self._blockdiag_bias:
                    if opType == 'Conv2D' and not self._approxT2:
                        # correct padding for numerical stability (we
                        # divided out Oh x Ow from activations for the T1
                        # approximation)
                        fpropFactor = tf.concat([
                            fpropFactor,
                            tf.ones([tf.shape(fpropFactor)[0], 1]) / Oh / Ow
                        ], 1)
                    else:
                        # use homogeneous coordinates
                        fpropFactor = tf.concat([
                            fpropFactor,
                            tf.ones([tf.shape(fpropFactor)[0], 1])
                        ], 1)

                # average over the number of data points in a batch,
                # i.e. divide by B
                cov = tf.matmul(fpropFactor, fpropFactor,
                                transpose_a=True) / tf.cast(B, tf.float32)
                updateOps.append(cov)
                statsUpdates[stats_var] = cov
                if opType != 'Conv2D':
                    # HACK: for convolution we recompute fprop stats for
                    # every layer including forking layers
                    statsUpdates_cache[stats_var] = cov

        for stats_var in bpropStats_vars:
            stats_var_dim = int(stats_var.get_shape()[0])
            if stats_var not in statsUpdates_cache:
                bpropFactor_shape = bpropFactor.get_shape()
                B = tf.shape(bpropFactor)[0]  # batch size
                C = int(bpropFactor_shape[-1])  # num channels
                if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                    if fpropFactor is not None:
                        if self._approxT2:
                            if KFAC_DEBUG:
                                print(('approxT2 grad fisher for %s'
                                       % (var.name)))
                            # T^2 terms * 1/T^2
                            bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])
                        else:
                            # T * 1/T terms
                            bpropFactor = tf.reshape(
                                bpropFactor, [-1, C]) * Oh * Ow
                    else:
                        # just doing block diagonal approx; the spatially
                        # independent structure does not apply here, so sum
                        # over spatial locations
                        if KFAC_DEBUG:
                            print(('block diag approx fisher for %s'
                                   % (var.name)))
                        bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                # assume sampled loss is averaged. TODO: figure out a better
                # way to handle this
                bpropFactor *= tf.to_float(B)

                cov_b = tf.matmul(
                    bpropFactor, bpropFactor,
                    transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])

                updateOps.append(cov_b)
                statsUpdates[stats_var] = cov_b
                statsUpdates_cache[stats_var] = cov_b

    if KFAC_DEBUG:
        aKey = list(statsUpdates.keys())[0]
        statsUpdates[aKey] = tf.Print(statsUpdates[aKey], [
            tf.convert_to_tensor('step:'),
            self.global_step,
            tf.convert_to_tensor('computing stats'),
        ])
    self.statsUpdates = statsUpdates
    return statsUpdates
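# Hedged usage sketch (assumes a KfacOptimizer-style instance `optim` and a
# scalar `loss_sampled` built from labels sampled out of the model's own
# predictive distribution, as K-FAC's Fisher estimate requires): compute_stats
# returns a dict mapping each stats variable to the covariance tensor that
# should be accumulated into it.
def _example_compute_stats(optim, loss_sampled):
    stats_updates = optim.compute_stats(loss_sampled)
    for stats_var, cov in stats_updates.items():
        # each cov is a square second-moment matrix matching stats_var's shape
        print(stats_var.name, cov.get_shape())
    return stats_updates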
def clipoutNeg(vec, threshold=1e-6):
    """Zero out entries of vec that are not strictly above threshold."""
    mask = tf.cast(vec > threshold, tf.float32)
    return mask * vec
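# Hedged sketch: clipoutNeg suppresses entries at or below the threshold, which
# is how near-zero and negative eigenvalues are clipped out before the factors
# are inverted. The input values below are illustrative only.
def _example_clipout_neg():
    eigvals = tf.constant([2.0, 1e-8, -0.5, 3.0])
    return clipoutNeg(eigvals)  # second and third entries are zeroed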
def sample(self):
    # draw one sample from each per-dimension categorical and stack them
    # along the last axis
    return tf.cast(
        tf.stack([p.sample() for p in self.categoricals], axis=-1),
        tf.int32)
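# Hedged sketch (assumes `pd` is a multi-categorical distribution object like
# the one above, holding a list of per-dimension categoricals, and a TF1
# session): sample() draws one index from each sub-distribution, so a batch of
# B inputs yields an int32 tensor of shape [B, num_subactions].
def _example_multicategorical_sample(sess, pd):
    return sess.run(pd.sample())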