def sessrun(*args, **kwargs):
  """Drop-in replacement for sess.run: when GLOBAL_PROFILE is set, runs with
  full tracing and writes a Chrome trace and RunMetadata dump under
  timelines/, named after the first fetch."""

  sess = u.get_default_session()
  if not GLOBAL_PROFILE:
    return sess.run(*args, **kwargs)

  run_metadata = tf.RunMetadata()
  kwargs['options'] = full_trace_options
  kwargs['run_metadata'] = run_metadata

  result = sess.run(*args, **kwargs)

  # derive a filename from the first fetch
  first_entry = args[0]
  if isinstance(first_entry, list):
    if len(first_entry) == 0 and len(args) == 1:
      return None
    first_entry = first_entry[0]
  name = first_entry.name.replace('/', '-')

  tl = timeline.Timeline(run_metadata.step_stats)
  ctf = tl.generate_chrome_trace_format()
  with open('timelines/%s.json' % (name,), 'w') as f:
    f.write(ctf)
  with open('timelines/%s.pbtxt' % (name,), 'w') as f:
    f.write(str(run_metadata))
  return result
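# A minimal, self-contained sketch of the TF1 tracing pattern that sessrun
# wraps, for reference; full_trace_options above is presumably built as shown
# here, and the demo graph and the _profile_sketch name are illustrative, not
# part of this codebase:
def _profile_sketch():
  import tensorflow as tf
  from tensorflow.python.client import timeline

  x = tf.constant(1.0, name='x')
  y = tf.square(x, name='y')
  opts = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()
  with tf.Session() as sess:
    sess.run(y, options=opts, run_metadata=run_metadata)
  # step_stats -> Chrome trace JSON, viewable at chrome://tracing
  ctf = timeline.Timeline(run_metadata.step_stats).generate_chrome_trace_format()
  with open('demo_timeline.json', 'w') as f:
    f.write(ctf)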
def advance_batch():
  # print("Step for model(%s) is %s" % (model.name, u.eval(model.step)))
  sess = u.get_default_session()
  # TODO: get rid of _sampled_labels
  sessrun([sampled_labels.initializer, _sampled_labels.initializer])
  if args.advance_batch:
    sessrun(advance_batch_op)
  sessrun(advance_step_op)
def __init__(self, model_creator, dsize):
  """model_creator: function that creates Model object
  dsize: controls the batch size of the model used to collect
    KFAC statistics."""

  s = self   # use for private members, i.e., s.some_internal_val

  s.lock = threading.Lock()  # broadest lock, covering all corrections

  s.model = model_creator(dsize, name="kfac")
  if do_early_init:
    s.model.initialize_local_vars()
    s.model.initialize_global_vars()

  s.log = OrderedDict()

  # regular gradient
  s.grad = IndexedGrad(loss=s.model.loss, vars_=s.model.trainable_vars)
  # gradient with synthetic backprops
  s.grad2 = IndexedGrad(loss=s.model.loss2, vars_=s.model.trainable_vars)

  s.lr = VarStruct(initial_value=-np.inf, name="lr", dtype=dtype)
  s.Lambda = VarStruct(initial_value=-np.inf, name="Lambda", dtype=dtype)

  # covariance and SVD ops for all correctable ops, keyed by the parameter
  # variable to correct
  s.kfac_correction_dict = OrderedDict()

  for var in s.model.trainable_vars:
    if not s.needs_correction(var):
      continue
    A = s.extract_A(s.grad2, var)
    B2 = s.extract_B(s.grad2, var)  # TODO: change to extract_B
    s.register_correction(var)
    dsize_op = tf.constant(dsize, dtype=dtype)
    s[var].A = Covariance(A, var, "A", s.Lambda.var)
    # dsize is already incorporated as part of backprop, so must
    # multiply to get B's on the same scale as A's
    s[var].B2 = Covariance(B2 * dsize_op, var, "B2", s.Lambda.var)

  s.grad_new = s.correct(s.grad)
  s.grad_dot_grad_new_op = tf.reduce_sum(s.grad.f * s.grad_new.f)
  s.grad_norm_op = u.L2(s.grad.f)
  s.grad_new_norm_op = u.L2(s.grad_new.f)

  # create parameter save and parameter restore ops
  s.param = VarList(s.model.trainable_vars)
  s.param_copy = s.param.copy()
  s.param_save_op = s.param_copy.assign(s.param)
  s.param_restore_op = s.param.assign(s.param_copy)
  s.param_update_op = s.param.sub(s.grad_new.cached, weight=s.lr)
  assert s.param.vars_ == s.grad_new.vars_

  s.sess = u.get_default_session()
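# A hedged numpy sketch of the per-layer correction that s.correct presumably
# applies; this mirrors standard KFAC (the exact code path lives in correct()
# and Covariance, which are not shown here). For a weight gradient G of shape
# [out, in], with input-activation covariance A = E[a a^T], backprop
# covariance B = E[d d^T], and damping Lambda, the corrected gradient is
# (B + Lambda*I)^-1 G (A + Lambda*I)^-1. The _kfac_correct_sketch name is
# illustrative, not part of this codebase.
def _kfac_correct_sketch(G, A, B, Lambda):
  import numpy as np
  out_dim, in_dim = G.shape
  A_damped = A + Lambda * np.eye(in_dim)    # damped input covariance
  B_damped = B + Lambda * np.eye(out_dim)   # damped backprop covariance
  # solve B_damped X = G, then right-multiply by A_damped^-1
  return np.linalg.solve(B_damped, G).dot(np.linalg.inv(A_damped))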
def advance_batch():
  print("Step for model(%s) is %s" % (model.name, u.eval(model.step)))
  sess = u.get_default_session()
  # TODO: get rid of _sampled_labels
  sessrun([sampled_labels.initializer, _sampled_labels.initializer])
  if args.advance_batch:
    with u.timeit("advance_batch"):
      sessrun(advance_batch_op)
  sessrun(advance_step_op)
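# u.timeit above is a timing context manager from the util module, which is
# not shown here; a minimal sketch of the assumed behavior, with an
# illustrative name:
import time
from contextlib import contextmanager

@contextmanager
def _timeit_sketch(tag):
  start = time.time()
  yield
  # report wall-clock time spent inside the with-block
  print("%s: %.3f sec" % (tag, time.time() - start))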
def initialize_global_vars(verbose=False, reinitialize=False):
  """If reinitialize is False, will not reinitialize variables already
  initialized."""

  sess = u.get_default_session()
  if not reinitialize:
    uninited = sessrun(global_init_query_ops)
    # use numpy boolean indexing to select list of initializers to run
    to_initialize = list(np.asarray(global_init_ops)[uninited])
  else:
    to_initialize = global_init_ops

  if verbose:
    print("Initializing following:")
    for v in to_initialize:
      print(" " + v.name)

  sessrun(to_initialize, feed_dict=init_dict)
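# A minimal sketch of the lazy-initialization pattern above, assuming TF1 and
# that global_init_query_ops reports True for uninitialized variables
# (e.g. via tf.logical_not(tf.is_variable_initialized(v))); the variables and
# the _lazy_init_sketch name are illustrative:
def _lazy_init_sketch():
  import numpy as np
  import tensorflow as tf

  v1 = tf.Variable(1.0, name='v1')
  v2 = tf.Variable(2.0, name='v2')
  init_ops = [v1.initializer, v2.initializer]
  query_ops = [tf.logical_not(tf.is_variable_initialized(v)) for v in (v1, v2)]

  with tf.Session() as sess:
    sess.run(v1.initializer)            # only v1 is initialized so far
    uninited = sess.run(query_ops)      # [False, True]
    # boolean indexing keeps initializers for uninitialized variables only
    to_init = list(np.asarray(init_ops)[uninited])
    sess.run(to_init)                   # initializes just v2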
def update(self):
  """Update cached gradients from their live values using the default
  session."""
  sess = u.get_default_session()
  # sess.run(self.update_op)
  u.run(self.update_op)
def initialize_local_vars():
  sess = u.get_default_session()
  sessrun(_sampled_labels.initializer, feed_dict=init_dict)
  sessrun(X.initializer, feed_dict=init_dict)
  sessrun(local_init_op, feed_dict=init_dict)