def apply_gradients(self, grads):
    """Return an op applying ``grads`` with either cold-start SGD or K-FAC.

    While ``self.sgd_step`` is not greater than ``self._cold_iter`` the cold
    branch runs a plain momentum step (learning rate ``self._cold_lr``,
    momentum ``self._momentum``). If ``self.max_grad_norm`` is set, that step
    is preceded by global-norm clipping, and it increments ``self.sgd_step``.
    Afterwards the K-FAC op built by ``self.apply_gradients_kfac`` is selected.

    Args:
        grads: iterable of ``(gradient, variable)`` pairs.

    Returns:
        ``(update_op, qr)``: ``update_op`` is the ``tf.cond`` selecting the
        cold or K-FAC step, and ``qr`` is whatever ``apply_gradients_kfac``
        returned as its second element.
    """
    coldOptim = tf.train.MomentumOptimizer(self._cold_lr, self._momentum)

    def coldSGDstart():
        sgd_grads, sgd_var = zip(*grads)
        if self.max_grad_norm is not None:
            # The global norm returned alongside the clipped grads is unused.
            sgd_grads, _ = tf.clip_by_global_norm(
                sgd_grads, self.max_grad_norm)
        sgd_grads = list(zip(sgd_grads, sgd_var))

        sgd_step_op = tf.assign_add(self.sgd_step, 1)
        coldOptim_op = coldOptim.apply_gradients(sgd_grads)
        if KFAC_DEBUG:
            # Print only after both the counter bump and the SGD step ran.
            with tf.control_dependencies([sgd_step_op, coldOptim_op]):
                sgd_step_op = tf.Print(sgd_step_op, [
                    self.sgd_step,
                    tf.convert_to_tensor('doing cold sgd step')
                ])
        return tf.group(sgd_step_op, coldOptim_op)

    # NOTE(review): this op is created outside the cond branch functions.
    # TF1 documents that ops created outside true_fn/false_fn and consumed by
    # a branch are executed regardless of the branch taken -- confirm that
    # running the K-FAC op during the cold phase is intended.
    kfacOptim_op, qr = self.apply_gradients_kfac(grads)

    def warmKFACstart():
        return kfacOptim_op

    return tf.cond(tf.greater(self.sgd_step, self._cold_iter),
                   warmKFACstart, coldSGDstart), qr
def updateOptimOp():
    """Build the optimizer op that applies the updates ``u`` to ``varlist``.

    Relies on the enclosing-scope names ``self``, ``optim``, ``u`` and
    ``varlist``. With ``self._full_stats_init`` set, the update is gated on
    ``self.factor_step > 0`` (otherwise a ``tf.no_op`` is taken). Without it,
    the update is applied unconditionally.
    """
    def _apply_updates():
        # Pair each update tensor with its variable for the optimizer.
        return optim.apply_gradients(list(zip(u, varlist)))

    if not self._full_stats_init:
        return _apply_updates()

    factors_ready = tf.greater(self.factor_step, tf.convert_to_tensor(0))
    return tf.cond(factors_ready, _apply_updates, tf.no_op)
def detectMinVal(input_mat, var, threshold=1e-6, name='', debug=False):
    """Return ``clipoutNeg(input_mat, threshold)``, optionally with a debug print.

    Args:
        input_mat: tensor to clip; presumably a vector of eigenvalues of a
            K-FAC factor -- TODO confirm against callers.
        var: variable whose ``.name`` is included in the debug message; only
            used when ``debug`` is True.
        threshold: forwarded to ``clipoutNeg`` (defined elsewhere in the file).
        name: label inserted into the debug message.
        debug: when True, the result is wrapped in a ``tf.Print``. The print
            fires when ``max(input_mat) / min(input_mat)`` lies in
            ``[-500, 0]`` or is NaN.

    Returns:
        The clipped tensor.
    """
    if not debug:
        # The min/max/ratio tensors exist only to feed the debug print, so
        # don't add them to the graph when debugging is off.
        return clipoutNeg(input_mat, threshold)

    eigen_min = tf.reduce_min(input_mat)
    eigen_max = tf.reduce_max(input_mat)
    eigen_ratio = eigen_max / eigen_min
    input_mat_clipped = clipoutNeg(input_mat, threshold)
    # Accept a positive ratio (min and max share a sign) or a ratio below -500
    # (the negative extreme is tiny relative to the positive one). Anything
    # else, including NaN (all comparisons false), triggers the print.
    ratio_ok = tf.logical_or(tf.greater(eigen_ratio, 0.),
                             tf.less(eigen_ratio, -500))
    return tf.cond(
        ratio_ok,
        lambda: input_mat_clipped,
        lambda: tf.Print(input_mat_clipped, [
            tf.convert_to_tensor('screwed ratio ' + name + ' eigen values!!!'),
            tf.convert_to_tensor(var.name),
            eigen_min, eigen_max, eigen_ratio
        ]))
def updateAccumStats():
    """Build the op that accumulates ``statsUpdates`` via ``self._apply_stats``.

    Relies on the enclosing-scope names ``self`` and ``statsUpdates``. The
    statistics are applied with ``accumulate=True`` and
    ``accumulateCoeff=1. / self._stats_accum_iter`` (presumably averaging over
    the accumulation window -- confirm in ``_apply_stats``). With
    ``self._full_stats_init`` set, this is gated on
    ``self.sgd_step > self._cold_iter`` (otherwise a ``tf.no_op`` is taken).
    """
    def _accumulate():
        coeff = 1. / self._stats_accum_iter
        stat_ops = self._apply_stats(
            statsUpdates, accumulate=True, accumulateCoeff=coeff)
        return tf.group(*stat_ops)

    if not self._full_stats_init:
        return _accumulate()

    past_cold_phase = tf.greater(self.sgd_step, self._cold_iter)
    return tf.cond(past_cold_phase, _accumulate, tf.no_op)
def apply_gradients_kfac(self, grads):
    """Build the K-FAC training op for ``grads``.

    The returned group contains, in order of construction:

    1. an increment of ``self.global_step``;
    2. ``self._update_stats_op`` (which must already have been built);
    3. the factor eigen-decomposition update (``computeStatsEigen`` /
       ``applyStatsEigen``), which runs in one of two modes:

       * synchronously, every ``self._kfac_update`` stats steps once
         ``self.stats_step >= self._stats_accum_iter``;
       * asynchronously (``self._async``), by dequeuing from a capacity-1
         ``tf.FIFOQueue`` filled by a ``tf.train.QueueRunner``.

       When no decomposition is applied, the fallback branch increments
       ``self.cold_step``;
    4. a momentum-optimizer step (learning rate
       ``self._lr * (1 - self._momentum)``). It applies
       ``self.getKfacPrecondUpdates`` output when ``self.factor_step > 0`` and
       the raw gradients otherwise, and is itself gated on
       ``self.stats_step >= self._stats_accum_iter`` (full-stats init) or
       ``self.sgd_step >= self._cold_iter``.

    Args:
        grads: iterable of ``(gradient, variable)`` pairs.

    Returns:
        ``(update_op, qr)``: ``update_op`` groups the ops above, and ``qr`` is
        the queue runner feeding the eigen-decomposition queue when
        ``self._async`` is set, otherwise ``None``.
    """
    g, varlist = list(zip(*grads))

    if len(self.stats_eigen) == 0:
        self.getStatsEigen()

    qr = None
    # launch eigen-decomp on a queue thread
    if self._async:
        print('Use async eigen decomp')
        # get a list of factor loading tensors; these are used here only to
        # derive the queue's dtypes and shapes
        factorOps_dummy = self.computeStatsEigen()

        # define a queue for the list of factor loading tensors
        queue = tf.FIFOQueue(
            1, [item.dtype for item in factorOps_dummy],
            shapes=[item.get_shape() for item in factorOps_dummy])
        # Enqueue a fresh decomposition on the same schedule as the sync path.
        enqueue_op = tf.cond(
            tf.logical_and(
                tf.equal(tf.mod(self.stats_step, self._kfac_update),
                         tf.convert_to_tensor(0)),
                tf.greater_equal(self.stats_step, self._stats_accum_iter)),
            lambda: queue.enqueue(self.computeStatsEigen()), tf.no_op)

        def dequeue_op():
            return queue.dequeue()

        qr = tf.train.QueueRunner(queue, [enqueue_op])

    updateOps = []
    global_step_op = tf.assign_add(self.global_step, 1)
    updateOps.append(global_step_op)

    with tf.control_dependencies([global_step_op]):

        # compute updates
        assert self._update_stats_op is not None, (
            '_update_stats_op is None; build the stats update op before '
            'calling apply_gradients_kfac')
        updateOps.append(self._update_stats_op)
        dependency_list = []
        if not self._async:
            # Sync mode: order the eigen-decomposition after the stats update.
            dependency_list.append(self._update_stats_op)

        with tf.control_dependencies(dependency_list):
            def no_op_wrapper():
                # Fallback branch: no decomposition applied, only bump the
                # cold_step counter.
                return tf.group(tf.assign_add(self.cold_step, 1))

            if not self._async:
                # synchronous eigen-decomp updates
                updateFactorOps = tf.cond(
                    tf.logical_and(
                        tf.equal(
                            tf.mod(self.stats_step, self._kfac_update),
                            tf.convert_to_tensor(0)),
                        tf.greater_equal(self.stats_step,
                                         self._stats_accum_iter)),
                    lambda: tf.group(*self.applyStatsEigen(
                        self.computeStatsEigen())),
                    no_op_wrapper)
            else:
                # asynchronous eigen-decomp updates using queue; a queued
                # decomposition is applied only when one is available
                updateFactorOps = tf.cond(
                    tf.greater_equal(self.stats_step, self._stats_accum_iter),
                    lambda: tf.cond(
                        tf.equal(queue.size(), tf.convert_to_tensor(0)),
                        tf.no_op,
                        lambda: tf.group(*self.applyStatsEigen(dequeue_op())),
                    ),
                    no_op_wrapper)

            updateOps.append(updateFactorOps)

            with tf.control_dependencies([updateFactorOps]):
                def gradOp():
                    return list(g)

                def getKfacGradOp():
                    return self.getKfacPrecondUpdates(g, varlist)

                # Preconditioned updates once factor_step > 0, raw gradients
                # before that.
                u = tf.cond(
                    tf.greater(self.factor_step, tf.convert_to_tensor(0)),
                    getKfacGradOp, gradOp)

                optim = tf.train.MomentumOptimizer(
                    self._lr * (1. - self._momentum), self._momentum)

                def optimOp():
                    def updateOptimOp():
                        if self._full_stats_init:
                            return tf.cond(
                                tf.greater(self.factor_step,
                                           tf.convert_to_tensor(0)),
                                lambda: optim.apply_gradients(
                                    list(zip(u, varlist))),
                                tf.no_op)
                        else:
                            return optim.apply_gradients(
                                list(zip(u, varlist)))

                    if self._full_stats_init:
                        return tf.cond(
                            tf.greater_equal(self.stats_step,
                                             self._stats_accum_iter),
                            updateOptimOp, tf.no_op)
                    else:
                        return tf.cond(
                            tf.greater_equal(self.sgd_step, self._cold_iter),
                            updateOptimOp, tf.no_op)

                updateOps.append(optimOp())

    return tf.group(*updateOps), qr