def compute_gradients(self, loss, var_list=None):
    varlist = var_list
    if varlist is None:
        varlist = tf.trainable_variables()
    g = tf.gradients(loss, varlist)
    return [(a, b) for a, b in zip(g, varlist)]
def compute_and_apply_stats(self, loss_sampled, var_list=None):
    varlist = var_list
    if varlist is None:
        varlist = tf.trainable_variables()
    stats = self.compute_stats(loss_sampled, var_list=varlist)
    return self.apply_stats(stats)
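# Illustrative usage sketch (commentary added here, not part of the original
# module): assuming the two methods above belong to a K-FAC optimizer class
# exposed as `KfacOptimizer` that also provides an `apply_gradients` method,
# and that `loss` and `loss_sampled` are existing scalar tensors in the
# default graph, a training graph could be wired up roughly like this
# (all names are hypothetical):
#
#   optim = KfacOptimizer(learning_rate=1e-3)
#   update_stats_op = optim.compute_and_apply_stats(loss_sampled)
#   grads_and_vars = optim.compute_gradients(loss)
#   train_op = optim.apply_gradients(grads_and_vars)
#   # run `update_stats_op` together with `train_op` at each session step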
def __init__(self, loss_scalar, solver_name='adam', base_lr=None,
             lr_policy=None, momentum=None, weight_decay=None, fc_vars=None,
             last_conv_vars=None, vars_to_opt=None):
    self.base_lr = base_lr
    self.lr_policy = lr_policy
    self.momentum = momentum
    self.solver_name = solver_name
    self.loss_scalar = loss_scalar
    if self.lr_policy != 'fixed':
        raise NotImplementedError(
            'learning rate policies other than fixed are not implemented')
    self.weight_decay = weight_decay
    if weight_decay is not None:
        if vars_to_opt is None:
            trainable_vars = tf.trainable_variables()
        else:
            trainable_vars = vars_to_opt
        loss_with_reg = self.loss_scalar
        for var in trainable_vars:
            loss_with_reg += self.weight_decay * tf.nn.l2_loss(var)
        self.loss_scalar = loss_with_reg
    self.solver_op = self.get_solver_op()
    if fc_vars is not None:
        self.fc_vars = fc_vars
        self.last_conv_vars = last_conv_vars
        self.fc_solver_op = self.get_solver_op(var_list=fc_vars)
def get_solver_op(self, var_list=None, loss=None):
    solver_string = self.solver_name.lower()
    if var_list is None:
        var_list = tf.trainable_variables()
    if loss is None:
        loss = self.loss_scalar
    if solver_string == 'adam':
        return tf.train.AdamOptimizer(
            learning_rate=self.base_lr,
            beta1=self.momentum).minimize(loss, var_list=var_list)
    elif solver_string == 'rmsprop':
        return tf.train.RMSPropOptimizer(
            learning_rate=self.base_lr,
            decay=self.momentum).minimize(loss, var_list=var_list)
    elif solver_string == 'momentum':
        return tf.train.MomentumOptimizer(
            learning_rate=self.base_lr,
            momentum=self.momentum).minimize(loss, var_list=var_list)
    elif solver_string == 'adagrad':
        return tf.train.AdagradOptimizer(
            learning_rate=self.base_lr,
            initial_accumulator_value=self.momentum).minimize(
                loss, var_list=var_list)
    elif solver_string == 'sgd':
        return tf.train.GradientDescentOptimizer(
            learning_rate=self.base_lr).minimize(loss, var_list=var_list)
    else:
        raise NotImplementedError("Please select a valid optimizer.")
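# Illustrative usage sketch (commentary, not original code; the class name
# `TfSolver`, the `feed_dict`, and `num_steps` below are hypothetical).
# Note that `momentum` doubles as the second hyperparameter of whichever
# optimizer is selected (Adam beta1, RMSProp decay, Adagrad accumulator):
#
#   solver = TfSolver(loss_scalar=loss, solver_name='adam', base_lr=1e-3,
#                     lr_policy='fixed', momentum=0.9, weight_decay=1e-4)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       for _ in range(num_steps):
#           sess.run(solver.solver_op, feed_dict=feed_dict)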
def compute_stats(self, loss_sampled, var_list=None):
    varlist = var_list
    if varlist is None:
        varlist = tf.trainable_variables()

    gs = tf.gradients(loss_sampled, varlist, name='gradientsSampled')
    self.gs = gs
    factors = self.getFactors(gs, varlist)
    stats = self.getStats(factors, varlist)

    updateOps = []
    statsUpdates = {}
    statsUpdates_cache = {}
    for var in varlist:
        opType = factors[var]['opName']
        fops = factors[var]['op']
        fpropFactor = factors[var]['fpropFactors_concat']
        fpropStats_vars = stats[var]['fprop_concat_stats']
        bpropFactor = factors[var]['bpropFactors_concat']
        bpropStats_vars = stats[var]['bprop_concat_stats']
        SVD_factors = {}
        for stats_var in fpropStats_vars:
            stats_var_dim = int(stats_var.get_shape()[0])
            if stats_var not in statsUpdates_cache:
                old_fpropFactor = fpropFactor
                B = tf.shape(fpropFactor)[0]  # batch size
                if opType == 'Conv2D':
                    strides = fops.get_attr("strides")
                    padding = fops.get_attr("padding")
                    convkernel_size = var.get_shape()[0:3]

                    KH = int(convkernel_size[0])
                    KW = int(convkernel_size[1])
                    C = int(convkernel_size[2])
                    flatten_size = int(KH * KW * C)

                    Oh = int(bpropFactor.get_shape()[1])
                    Ow = int(bpropFactor.get_shape()[2])

                    if Oh == 1 and Ow == 1 and self._channel_fac:
                        # factorization along the channels;
                        # assume independence among input channels
                        # factor = B x 1 x 1 x (KH x KW x C)
                        # patches = B x Oh x Ow x (KH x KW x C)
                        if len(SVD_factors) == 0:
                            if KFAC_DEBUG:
                                print(('approx %s act factor with rank-1 SVD factors'
                                       % (var.name)))
                            # find closest rank-1 approx to the feature map
                            S, U, V = tf.batch_svd(
                                tf.reshape(fpropFactor, [-1, KH * KW, C]))
                            # get rank-1 approx slides
                            sqrtS1 = tf.expand_dims(tf.sqrt(S[:, 0, 0]), 1)
                            patches_k = U[:, :, 0] * sqrtS1  # B x KH*KW
                            full_factor_shape = fpropFactor.get_shape()
                            patches_k.set_shape(
                                [full_factor_shape[0], KH * KW])
                            patches_c = V[:, :, 0] * sqrtS1  # B x C
                            patches_c.set_shape([full_factor_shape[0], C])
                            SVD_factors[C] = patches_c
                            SVD_factors[KH * KW] = patches_k
                        fpropFactor = SVD_factors[stats_var_dim]
                    else:
                        # poor memory usage implementation
                        patches = tf.extract_image_patches(
                            fpropFactor,
                            ksizes=[1, convkernel_size[0],
                                    convkernel_size[1], 1],
                            strides=strides,
                            rates=[1, 1, 1, 1],
                            padding=padding)

                        if self._approxT2:
                            if KFAC_DEBUG:
                                print(('approxT2 act fisher for %s' % (var.name)))
                            # T^2 terms * 1/T^2, size: B x C
                            fpropFactor = tf.reduce_mean(patches, [1, 2])
                        else:
                            # size: (B x Oh x Ow) x C
                            fpropFactor = tf.reshape(
                                patches, [-1, flatten_size]) / Oh / Ow
                fpropFactor_size = int(fpropFactor.get_shape()[-1])
                if stats_var_dim == (fpropFactor_size + 1) and not self._blockdiag_bias:
                    if opType == 'Conv2D' and not self._approxT2:
                        # correct padding for numerical stability (we
                        # divided out Oh x Ow from activations for T1 approx)
                        fpropFactor = tf.concat([
                            fpropFactor,
                            tf.ones([tf.shape(fpropFactor)[0], 1]) / Oh / Ow
                        ], 1)
                    else:
                        # use homogeneous coordinates
                        fpropFactor = tf.concat([
                            fpropFactor,
                            tf.ones([tf.shape(fpropFactor)[0], 1])
                        ], 1)

                # average over the number of data points in a batch
                # (divided by B)
                cov = tf.matmul(fpropFactor, fpropFactor,
                                transpose_a=True) / tf.cast(B, tf.float32)
                updateOps.append(cov)
                statsUpdates[stats_var] = cov
                if opType != 'Conv2D':
                    # HACK: for convolution we recompute fprop stats for
                    # every layer including forking layers
                    statsUpdates_cache[stats_var] = cov

        for stats_var in bpropStats_vars:
            stats_var_dim = int(stats_var.get_shape()[0])
            if stats_var not in statsUpdates_cache:
                old_bpropFactor = bpropFactor
                bpropFactor_shape = bpropFactor.get_shape()
                B = tf.shape(bpropFactor)[0]  # batch size
                C = int(bpropFactor_shape[-1])  # num channels
                if opType == 'Conv2D' or len(bpropFactor_shape) == 4:
                    if fpropFactor is not None:
                        if self._approxT2:
                            if KFAC_DEBUG:
                                print(('approxT2 grad fisher for %s' % (var.name)))
                            bpropFactor = tf.reduce_sum(
                                bpropFactor, [1, 2])  # T^2 terms * 1/T^2
                        else:
                            bpropFactor = tf.reshape(
                                bpropFactor, [-1, C]) * Oh * Ow  # T * 1/T terms
                    else:
                        # just doing block diag approx; spatial independent
                        # structure does not apply here, so sum over
                        # spatial locations
                        if KFAC_DEBUG:
                            print(('block diag approx fisher for %s' % (var.name)))
                        bpropFactor = tf.reduce_sum(bpropFactor, [1, 2])

                # assume sampled loss is averaged. TODO: figure out a better
                # way to handle this
                bpropFactor *= tf.to_float(B)

                cov_b = tf.matmul(
                    bpropFactor, bpropFactor,
                    transpose_a=True) / tf.to_float(tf.shape(bpropFactor)[0])

                updateOps.append(cov_b)
                statsUpdates[stats_var] = cov_b
                statsUpdates_cache[stats_var] = cov_b

    if KFAC_DEBUG:
        aKey = list(statsUpdates.keys())[0]
        statsUpdates[aKey] = tf.Print(statsUpdates[aKey], [
            tf.convert_to_tensor('step:'),
            self.global_step,
            tf.convert_to_tensor('computing stats'),
        ])
    self.statsUpdates = statsUpdates
    return statsUpdates
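# Note on the statistics computed above (added commentary, not original code):
# for a layer with input activations `a` (shape B x D_in, extended with a
# homogeneous 1-column when the bias is folded in) and sampled output
# gradients `g` (shape B x D_out), the two covariances are roughly
#     cov   = a^T a / B    (fprop factor)
#     cov_b = g^T g / B    (bprop factor; `g` is first rescaled because the
#                           sampled loss is batch-averaged)
# and K-FAC approximates the layer's Fisher block by the Kronecker product
# of these two factors. For Conv2D layers, `a` is built from image patches,
# so spatial locations are treated as extra batch entries.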
def _serialize_variables():
    sess = get_session()
    variables = tf.trainable_variables()
    values = sess.run(variables)
    return {var.name: value for var, value in zip(variables, values)}
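# Illustrative usage sketch (commentary, not original code; assumes
# `get_session` returns the active tf.Session, and the file name below is
# hypothetical). The returned dict maps variable names to their current
# numpy values, which makes it easy to inspect or restore parameters
# outside of TensorFlow checkpoints:
#
#   import pickle
#   with open('params.pkl', 'wb') as f:
#       pickle.dump(_serialize_variables(), f)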