def score(s, produced):
  global counter
  global cache
  #if counter == 20:
  #  exit()
  print "score called"
  print s
  counter += 1
  subexps = sexp.subexps(s)
  #if s in subexps:
  #  subexps.remove(s)
  #print subexps
  splits = sum([sexp.split(s, subexp) for subexp in subexps], [])
  #print splits
  #for split in splits:
  #  print split
  #print
  if not splits:
    #print "hit bottom"
    #print
    #print 1
    return 1
  #print
  print sexp.pretty_lambda(s)
  for x in splits:
    print "  ", sexp.pretty_lambda(x[0]), "/", sexp.pretty_lambda(x[1])
  print
  scr = max([score_one_split(s, x[0], x[1], produced) for x in splits])
  #print scr
  return scr
def random_derivation(s, depth = 0, category = 'S', productions_above = None):
  """Randomly build a derivation of expression `s` with category `category`.

  If the pretty-printed form of `s` is a lexicon key, a random lexicon entry
  whose category matches `category` is chosen; categories are compared after
  replacing both '/' and '\\' with '|'. The result is
  `[key, entry[0], entry[1]]`, or False if no entry matches.

  Otherwise the splits of `s` are tried in random order. For each split,
  the left part gets category `catf(left)` and the right part gets
  `category|left_category` (the left category is parenthesized when it
  contains '|'). The first split for which both sides derive returns
  `[key, category, left_derivation, right_derivation]`.

  Args:
    s: expression to derive.
    depth: current recursion depth; the search gives up once it exceeds 3.
    category: target category ('|'-separated form).
    productions_above: splits taken on the path to this call. It is only
      accumulated and passed down, not otherwise read here. Defaults to an
      empty list.

  Returns:
    A nested list derivation, or False when none is found (vacuous
    expression, depth exceeded, no splits, or no split derives).
  """
  global lexicon

  # A None sentinel avoids the shared-mutable-default pitfall.
  if productions_above is None:
    productions_above = []

  key = sexp.pretty_lambda(s)

  if key in lexicon:
    # Lexicon categories may use directional slashes; compare them in the
    # direction-agnostic '|' form used by `category`.
    options = [l for l in lexicon[key]
               if l[0].replace('/', '|').replace('\\', '|') == category]
    if not options:
      return False
    picked = random.choice(options)
    return [key, picked[0], picked[1]]

  if sexp.totally_vacuous(s):
    return False

  if depth > 3:
    return False

  splits = [sp for sub in sexp.subexps(s) for sp in sexp.split(s, sub)]
  if not splits:
    return False
  random.shuffle(splits)

  for split in splits:
    f = split[0]
    fcat = catf(f)
    fcat2 = '(%s)' % fcat if '|' in fcat else fcat
    gcat = '%s|%s' % (category, fcat2)
    below = productions_above + [split]

    d1 = random_derivation(f, depth + 1, fcat, below)
    if not d1:
      # The right side would be discarded anyway; skip its recursive search.
      continue
    d2 = random_derivation(split[1], depth + 1, gcat, below)
    if d2:
      return [key, category, d1, d2]
  return False
def best_derivation(sent, category, cky=None, depth=0):
  """Search for the highest-scoring derivation of `sent` with `category`.

  If the pretty-printed form of `sent` is a lexicon key, the candidate
  lexicon entries from `all_lex_entries` are each scored with `lm_score` and
  returned as a leaf: `{'key', 'scored', 'score'}`. Here 'scored' lists
  `(terminal, lm_score)` pairs and 'score' is the best of those scores.

  Otherwise `sent` is split against each of its subexpressions. Each split
  is scored as left score + right score + `split_potential`, and the
  highest-scoring `{'key', 'score', 'left', 'right'}` node is returned.

  Args:
    sent: expression to derive.
    category: target category for `sent`.
    cky: context list passed to `lm_score`; each split receives its own copy.
      Defaults to an empty list.
    depth: current recursion depth; the search stops at depth 3.

  Returns:
    A derivation dict as described above, or False when no derivation exists.

  Side effects: increments the global `counter`. Writes results into the
  global `cache`, keyed by "<pretty form> <category>"; the cache is not
  read here.
  """
  global counter
  global cache
  global lexicon

  if cky is None:
    cky = []

  lkey = sexp.pretty_lambda(sent)
  key = lkey + ' ' + category

  counter += 1

  if lkey in lexicon:
    terminals = all_lex_entries(lkey, category)
    scored = [(terminal, lm_score(terminal, cky)) for terminal in terminals]
    if scored:
      # The recursive branch below reads child['score'], so leaves must
      # carry one. NOTE(review): assumes a higher lm_score is better,
      # matching the max() used to pick among splits.
      r = {'key': key,
           'scored': scored,
           'score': max(sc for _, sc in scored)}
    else:
      r = False
    cache[key] = r
    return r

  if sexp.totally_vacuous(sent):
    r = False
    cache[key] = r
    return r

  # `>=` (not `==`) so a caller starting deeper cannot recurse without bound.
  # Not cached: the outcome depends on depth, which the key omits.
  if depth >= 3:
    return False

  splits = [sp for sub in sexp.subexps(sent) for sp in sexp.split(sent, sub)]

  scores = []
  for split in splits:
    ncky = list(cky)
    (fcat, gcat) = make_categories(split, category)
    left = best_derivation(split[0], fcat, ncky, depth + 1)
    if not left:
      continue
    right = best_derivation(split[1], gcat, ncky, depth + 1)
    if not right:
      continue
    sc = left['score'] + right['score'] + split_potential(sent, split)
    scores.append({'key': key,
                   'score': sc,
                   'left': left,
                   'right': right})

  if not scores:
    return False
  r = max(scores, key=lambda x: x['score'])
  cache[key] = r
  return r