Пример #1
0
def sessrun(*args, **kwargs):
    """Wrapper around ``sess.run`` that optionally captures a profiling trace.

    When ``GLOBAL_PROFILE`` is falsy this is a plain pass-through to
    ``sess.run``. Otherwise the run is executed with full tracing enabled
    and two files are written under ``timelines/``, named after the first
    fetched tensor/op: a Chrome-trace JSON timeline and the raw
    ``RunMetadata`` as pbtxt.

    Returns the result of ``sess.run`` (or ``None`` when the only fetch is
    an empty list).
    """
    sess = u.get_default_session()
    if not GLOBAL_PROFILE:
        return sess.run(*args, **kwargs)

    run_metadata = tf.RunMetadata()

    kwargs['options'] = full_trace_options
    kwargs['run_metadata'] = run_metadata
    result = sess.run(*args, **kwargs)

    # Bug fix: the original indexed args[0] unconditionally and raised
    # IndexError when fetches were passed via kwargs only.
    if not args:
        return result
    first_entry = args[0]
    if isinstance(first_entry, list):
        if not first_entry and len(args) == 1:
            return None
        first_entry = first_entry[0]
    # '/' is common in TF op names but not valid inside a file name.
    name = first_entry.name.replace('/', '-')

    tl = timeline.Timeline(run_metadata.step_stats)
    ctf = tl.generate_chrome_trace_format()
    with open('timelines/%s.json' % (name, ), 'w') as f:
        f.write(ctf)
    with open('timelines/%s.pbtxt' % (name, ), 'w') as f:
        f.write(str(run_metadata))
    return result
Пример #2
0
def sessrun(*args, **kwargs):
  """Run fetches through the default session; when GLOBAL_PROFILE is on,
  also dump a Chrome-trace timeline and the RunMetadata under timelines/,
  named after the first fetched tensor/op."""
  sess = u.get_default_session()
  if not GLOBAL_PROFILE:
    return sess.run(*args, **kwargs)

  metadata = tf.RunMetadata()
  kwargs.update(options=full_trace_options, run_metadata=metadata)
  result = sess.run(*args, **kwargs)

  target = args[0]
  if isinstance(target, list):
    # A single empty-list fetch has nothing to name the trace after.
    if len(args) == 1 and len(target) == 0:
      return None
    target = target[0]
  # Op names may contain '/', which is not filename-safe.
  fname = target.name.replace('/', '-')

  trace = timeline.Timeline(metadata.step_stats)
  chrome_trace = trace.generate_chrome_trace_format()
  with open('timelines/%s.json'%(fname,), 'w') as out:
    out.write(chrome_trace)
  with open('timelines/%s.pbtxt'%(fname,), 'w') as out:
    out.write(str(metadata))
  return result
Пример #3
0
 def advance_batch():
     """Advance the input pipeline one batch and increment the step counter.

     Re-initializes the label samplers, optionally advances the batch
     (controlled by args.advance_batch), then advances the step op.
     """
     # Fix: removed unused local `sess` and a commented-out debug print.
     # TODO: get rid of _sampled_labels
     sessrun([sampled_labels.initializer, _sampled_labels.initializer])
     if args.advance_batch:
         sessrun(advance_batch_op)
     sessrun(advance_step_op)
Пример #4
0
  def __init__(self, model_creator, dsize):
    """Build the KFAC correction state for a freshly created model.

    Args:
      model_creator: callable invoked as model_creator(dsize, name=...)
        that creates a Model object.
      dsize: batch size of the model used to collect KFAC statistics.
    """
    
    s = self       # use for private members, ie, s.some_internal_val

    s.lock = threading.Lock()    # most broad lock covering all corrections
    
    s.model = model_creator(dsize, name="kfac")
    if do_early_init:
      s.model.initialize_local_vars()
      s.model.initialize_global_vars()
      
    s.log = OrderedDict()

    # regular gradient
    s.grad = IndexedGrad(loss=s.model.loss, vars_=s.model.trainable_vars)
    
    # gradient with synthetic backprops
    s.grad2 = IndexedGrad(loss=s.model.loss2, vars_=s.model.trainable_vars)

    # Learning rate and damping start at -inf so an unset value is obvious
    # if used before being assigned.
    s.lr = VarStruct(initial_value=-np.inf, name="lr", dtype=dtype)
    s.Lambda = VarStruct(initial_value=-np.inf, name="Lambda", dtype=dtype)
    
    # covariance and SVD ops for all correctable ops, mapped to parameter
    # variable to correct
    s.kfac_correction_dict = OrderedDict()
    
    for var in s.model.trainable_vars:
      if not s.needs_correction(var):
        continue
      A = s.extract_A(s.grad2, var)
      B2 = s.extract_B(s.grad2, var)  # todo: change to extract_B
      s.register_correction(var)
      dsize_op = tf.constant(dsize, dtype=dtype)
      s[var].A = Covariance(A, var, "A", s.Lambda.var)
      # dsize is already incorporated as part of backprop, so must
      # multiply to get B's on the same scale as the A covariances
      s[var].B2 = Covariance(B2*dsize_op, var, "B2", s.Lambda.var)

      
    # Diagnostics: corrected gradient, <g, g_new>, and both norms.
    s.grad_new = s.correct(s.grad)
    s.grad_dot_grad_new_op = tf.reduce_sum(s.grad.f * s.grad_new.f)
    s.grad_norm_op = u.L2(s.grad.f)
    s.grad_new_norm_op = u.L2(s.grad_new.f)

    # create parameter save and parameter restore ops
    s.param = VarList(s.model.trainable_vars)
    s.param_copy = s.param.copy()
    s.param_save_op = s.param_copy.assign(s.param)
    s.param_restore_op = s.param.assign(s.param_copy)
    s.param_update_op = s.param.sub(s.grad_new.cached, weight=s.lr)
    assert s.param.vars_ == s.grad_new.vars_
    
    s.sess = u.get_default_session()
Пример #5
0
 def advance_batch():
   """Advance the input pipeline one batch (timed) and bump the step op."""
   print("Step for model(%s) is %s"%(model.name, u.eval(model.step)))
   sess = u.get_default_session()  # NOTE(review): unused local
   # TODO: get rid of _sampled_labels
   sessrun([sampled_labels.initializer, _sampled_labels.initializer])
   if args.advance_batch:
     with u.timeit("advance_batch"):
       sessrun(advance_batch_op)
   sessrun(advance_step_op)
Пример #6
0
    def initialize_global_vars(verbose=False, reinitialize=False):
        """Run initializers for global variables.

        Args:
          verbose: if True, print the name of every variable initialized.
          reinitialize: if False, variables that are already initialized
            are skipped; if True, every global initializer is run.
        """
        # Fix: removed unused local `sess` (nothing here used it).
        if not reinitialize:
            uninited = sessrun(global_init_query_ops)
            # use numpy boolean indexing to select list of initializers to run
            to_initialize = list(np.asarray(global_init_ops)[uninited])
        else:
            to_initialize = global_init_ops

        if verbose:
            print("Initializing following:")
            for v in to_initialize:
                print("   " + v.name)

        sessrun(to_initialize, feed_dict=init_dict)
Пример #7
0
  def initialize_global_vars(verbose=False, reinitialize=False):
    """Run global-variable initializers; when reinitialize is False, only
    variables not yet initialized are touched."""

    sess = u.get_default_session()
    if reinitialize:
      to_initialize = global_init_ops
    else:
      uninited = sessrun(global_init_query_ops)
      # numpy boolean mask selects the initializers of uninitialized vars
      to_initialize = list(np.asarray(global_init_ops)[uninited])

    if verbose:
      print("Initializing following:")
      for v in to_initialize:
        print("   " + v.name)

    sessrun(to_initialize, feed_dict=init_dict)
Пример #8
0
 def update(self):
   """Refresh cached gradients from their live values using the
   default session."""
   # Fix: removed unused local `sess` and the commented-out duplicate
   # of the run call; also corrected the garbled docstring wording.
   u.run(self.update_op)
Пример #9
0
 def initialize_local_vars():
     """Run the local-variable initializers through sessrun with init_dict."""
     sess = u.get_default_session()  # NOTE(review): unused local
     sessrun(_sampled_labels.initializer, feed_dict=init_dict)
     sessrun(X.initializer, feed_dict=init_dict)
     sessrun(local_init_op, feed_dict=init_dict)
Пример #10
0
 def initialize_local_vars():
   """Run each local-variable initializer through sessrun with init_dict."""
   sess = u.get_default_session()
   # Order matters only in that it mirrors the original call sequence.
   for init_op in (_sampled_labels.initializer, X.initializer, local_init_op):
     sessrun(init_op, feed_dict=init_dict)