def _apply_stats(self, statsUpdates, accumulate=False, accumulateCoeff=0.):
    """Update the running statistics variables, either by accumulating a
    super-batch average or by an exponential moving average."""
    updateOps = []
    # obtain the stats var list
    for stats_var in statsUpdates:
        stats_new = statsUpdates[stats_var]
        if accumulate:
            # simple superbatch averaging
            update_op = tf.assign_add(
                stats_var, accumulateCoeff * stats_new, use_locking=True)
        else:
            # exponential running averaging
            update_op = tf.assign(
                stats_var, stats_var * self._stats_decay, use_locking=True)
            update_op = tf.assign_add(
                update_op, (1. - self._stats_decay) * stats_new, use_locking=True)
        updateOps.append(update_op)

    with tf.control_dependencies(updateOps):
        stats_step_op = tf.assign_add(self.stats_step, 1)

    if KFAC_DEBUG:
        stats_step_op = tf.Print(stats_step_op, [
            tf.convert_to_tensor('step:'),
            self.global_step,
            tf.convert_to_tensor('fac step:'),
            self.factor_step,
            tf.convert_to_tensor('sgd step:'),
            self.sgd_step,
            tf.convert_to_tensor('Accum:'),
            tf.convert_to_tensor(accumulate),
            tf.convert_to_tensor('Accum coeff:'),
            tf.convert_to_tensor(accumulateCoeff),
            tf.convert_to_tensor('stat step:'),
            self.stats_step, updateOps[0], updateOps[1]
        ])

    return [stats_step_op]
def applyStatsEigen(self, eigen_list):
    """Assign freshly computed eigenvalues/eigenvectors to their stats-eigen
    variables and advance the factor step counter."""
    updateOps = []
    print('updating %d eigenvalue/vectors' % len(eigen_list))
    for tensor, mark in zip(eigen_list, self.eigen_update_list):
        stats_eigen_var = self.eigen_reverse_lookup[mark]
        updateOps.append(
            tf.assign(stats_eigen_var, tensor, use_locking=True))

    with tf.control_dependencies(updateOps):
        factor_step_op = tf.assign_add(self.factor_step, 1)
        updateOps.append(factor_step_op)
        if KFAC_DEBUG:
            updateOps.append(
                tf.Print(tf.constant(0.),
                         [tf.convert_to_tensor('updated kfac factors')]))
    return updateOps
def coldSGDstart():
    # Cold-start branch: closure over `grads` and `coldOptim` from the
    # enclosing scope; takes a plain SGD step (with optional gradient
    # clipping) while the K-FAC statistics are still warming up.
    sgd_grads, sgd_var = zip(*grads)

    if self.max_grad_norm is not None:
        sgd_grads, sgd_grad_norm = tf.clip_by_global_norm(
            sgd_grads, self.max_grad_norm)

    sgd_grads = list(zip(sgd_grads, sgd_var))

    sgd_step_op = tf.assign_add(self.sgd_step, 1)
    coldOptim_op = coldOptim.apply_gradients(sgd_grads)
    if KFAC_DEBUG:
        with tf.control_dependencies([sgd_step_op, coldOptim_op]):
            sgd_step_op = tf.Print(sgd_step_op, [
                self.sgd_step,
                tf.convert_to_tensor('doing cold sgd step')
            ])
    return tf.group(*[sgd_step_op, coldOptim_op])
def no_op_wrapper():
    return tf.group(*[tf.assign_add(self.cold_step, 1)])
def apply_gradients_kfac(self, grads):
    """Build the K-FAC update: refresh statistics, recompute the Fisher-factor
    eigen decompositions (synchronously or via an async queue), precondition
    the gradients, and apply them with a momentum optimizer. Returns the
    grouped update op and the QueueRunner (None in the synchronous case)."""
    g, varlist = list(zip(*grads))

    if len(self.stats_eigen) == 0:
        self.getStatsEigen()

    qr = None
    # launch eigen-decomp on a queue thread
    if self._async:
        print('Use async eigen decomp')
        # get a list of factor loading tensors
        factorOps_dummy = self.computeStatsEigen()

        # define a queue for the list of factor loading tensors
        queue = tf.FIFOQueue(
            1, [item.dtype for item in factorOps_dummy],
            shapes=[item.get_shape() for item in factorOps_dummy])
        enqueue_op = tf.cond(
            tf.logical_and(
                tf.equal(tf.mod(self.stats_step, self._kfac_update),
                         tf.convert_to_tensor(0)),
                tf.greater_equal(self.stats_step, self._stats_accum_iter)),
            lambda: queue.enqueue(self.computeStatsEigen()), tf.no_op)

        def dequeue_op():
            return queue.dequeue()

        qr = tf.train.QueueRunner(queue, [enqueue_op])

    updateOps = []
    global_step_op = tf.assign_add(self.global_step, 1)
    updateOps.append(global_step_op)

    with tf.control_dependencies([global_step_op]):
        # compute updates
        assert self._update_stats_op is not None
        updateOps.append(self._update_stats_op)
        dependency_list = []
        if not self._async:
            dependency_list.append(self._update_stats_op)

        with tf.control_dependencies(dependency_list):
            def no_op_wrapper():
                return tf.group(*[tf.assign_add(self.cold_step, 1)])

            if not self._async:
                # synchronous eigen-decomp updates
                updateFactorOps = tf.cond(
                    tf.logical_and(
                        tf.equal(tf.mod(self.stats_step, self._kfac_update),
                                 tf.convert_to_tensor(0)),
                        tf.greater_equal(self.stats_step,
                                         self._stats_accum_iter)),
                    lambda: tf.group(*self.applyStatsEigen(
                        self.computeStatsEigen())),
                    no_op_wrapper)
            else:
                # asynchronous eigen-decomp updates using queue
                updateFactorOps = tf.cond(
                    tf.greater_equal(self.stats_step, self._stats_accum_iter),
                    lambda: tf.cond(
                        tf.equal(queue.size(), tf.convert_to_tensor(0)),
                        tf.no_op,
                        lambda: tf.group(*self.applyStatsEigen(dequeue_op()))),
                    no_op_wrapper)

            updateOps.append(updateFactorOps)

            with tf.control_dependencies([updateFactorOps]):
                def gradOp():
                    return list(g)

                def getKfacGradOp():
                    return self.getKfacPrecondUpdates(g, varlist)

                # use preconditioned gradients once at least one eigen
                # decomposition has been computed
                u = tf.cond(
                    tf.greater(self.factor_step, tf.convert_to_tensor(0)),
                    getKfacGradOp, gradOp)

                optim = tf.train.MomentumOptimizer(
                    self._lr * (1. - self._momentum), self._momentum)
                # optim = tf.train.AdamOptimizer(self._lr, epsilon=0.01)

                def optimOp():
                    def updateOptimOp():
                        if self._full_stats_init:
                            return tf.cond(
                                tf.greater(self.factor_step,
                                           tf.convert_to_tensor(0)),
                                lambda: optim.apply_gradients(
                                    list(zip(u, varlist))),
                                tf.no_op)
                        else:
                            return optim.apply_gradients(list(zip(u, varlist)))

                    if self._full_stats_init:
                        return tf.cond(
                            tf.greater_equal(self.stats_step,
                                             self._stats_accum_iter),
                            updateOptimOp, tf.no_op)
                    else:
                        return tf.cond(
                            tf.greater_equal(self.sgd_step, self._cold_iter),
                            updateOptimOp, tf.no_op)

                updateOps.append(optimOp())

    return tf.group(*updateOps), qr
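# The sketch below is not part of the original class; it is a minimal, assumed
# example of how the (update op, QueueRunner) pair returned by
# apply_gradients_kfac might be driven in a TF1 training loop when the async
# eigen-decomposition path is enabled. Names such as `optimizer`, `grads`,
# `num_steps`, and `feed` are hypothetical placeholders.
#
#     update_op, qr = optimizer.apply_gradients_kfac(grads)
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         coord = tf.train.Coordinator()
#         if qr is not None:
#             # start the background thread that keeps the eigen queue filled
#             qr.create_threads(sess, coord=coord, start=True)
#         for _ in range(num_steps):
#             sess.run(update_op, feed_dict=feed)
#         coord.request_stop()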