Exemple #1
0
def train(**kwargs):
  """"""
  
  # Get the special arguments
  load = kwargs.pop('load')
  force = kwargs.pop('force')
  noscreen = kwargs.pop('noscreen')
  save_dir = kwargs.pop('save_dir')
  save_metadir = kwargs.pop('save_metadir')
  network_class = kwargs.pop('network_class')
  config_file = kwargs.pop('config_file')
  
  # Get the cl-defined options
  kwargs = {key: value for key, value in six.iteritems(kwargs) if value is not None}
  for section, values in six.iteritems(kwargs):
    if section in section_names:
      values = [value.split('=', 1) for value in values]
      kwargs[section] = {opt: value for opt, value in values}
  if 'DEFAULT' not in kwargs:
    kwargs['DEFAULT'] = {}
  kwargs['DEFAULT']['network_class'] = network_class
    
  # Figure out the save_directory
  if save_metadir is not None:
    kwargs['DEFAULT']['save_metadir'] = save_metadir
  if save_dir is not None:
    kwargs['DEFAULT']['save_dir'] = save_dir
  config = Config(config_file=config_file, **kwargs)
  save_dir = config.get('DEFAULT', 'save_dir')
  
  # If not loading, ask the user if they want to overwrite the directory
  if not load and os.path.isdir(save_dir):
    if not force:
      input_str = ''
      while input_str not in ('y', 'n', 'yes', 'no'):
        input_str = input('{} already exists. It will be deleted if you continue. Do you want to proceed? [Y/n] '.format(save_dir)).lower()
      if input_str in ('n', 'no'):
        print()
        sys.exit(0)
    elif noscreen:
      sys.exit(0)
    shutil.rmtree(save_dir)
  
  # If the save_dir wasn't overwritten, load its config_file
  if os.path.isdir(save_dir):
    config_file = os.path.join(save_dir, 'config.cfg')
  else:
    os.makedirs(save_dir)
    os.system('git rev-parse HEAD >> {}'.format(os.path.join(save_dir, 'HEAD')))
  
  network_list = config.get(network_class, 'input_network_classes')
  if not load:
    with open(os.path.join(save_dir, 'config.cfg'), 'w') as f:
      config.write(f)
  input_networks, networks = resolve_network_dependencies(config, network_class, network_list, {})
  NetworkClass = getattr(parser, network_class)
  network = NetworkClass(input_networks=input_networks, config=config)
  network.train(load=load, noscreen=noscreen)
  return
Exemple #2
0
def run(**kwargs):
  """"""
  #pdb.set_trace()
  # Get the special arguments
  save_dir = kwargs.pop('save_dir')
  save_metadir = kwargs.pop('save_metadir')
  conllu_files = kwargs.pop('conllu_files')
  output_dir = kwargs.pop('output_dir')
  output_filename = kwargs.pop('output_filename')
  testing = kwargs.pop('testing')
  debug = kwargs.pop('debug')
  nornn = kwargs.pop('nornn')
  check_iter = kwargs.pop('check_iter')
  gen_tree = kwargs.pop('gen_tree')
  get_argmax = kwargs.pop('get_argmax')
  # Get the cl-defined options
  kwargs = {key: value for key, value in six.iteritems(kwargs) if value is not None}
  for section, values in six.iteritems(kwargs):
    if section in section_names:
      values = [value.split('=', 1) for value in values]
      kwargs[section] = {opt: value for opt, value in values}
  if 'DEFAULT' not in kwargs:
    kwargs['DEFAULT'] = {}
    
  # Figure out the save_directory
  if save_metadir is not None:
    kwargs['DEFAULT']['save_metadir'] = save_metadir
  if save_dir is None:
    save_dir = Config(**kwargs).get('DEFAULT', 'save_dir')
  config_file = os.path.join(save_dir, 'config.cfg')
  namelist=save_dir.split('GraphParserNetwork')
  kwargs['DEFAULT']['save_dir'] = save_dir
  #kwargs['DEFAULT']['save_dir'] = namelist[0]+'GraphParserNetwork'

  if testing:
    #kwargs['DEFAULT']['AUTO_dir'] = 'True'
    #kwargs['DEFAULT']['modelname'] = namelist[-1]
    kwargs['CoNLLUDataset']={}
    if debug:
      kwargs['CoNLLUDataset']['batch_size'] = 1000
      kwargs['CoNLLUDataset']['max_buckets'] = 30
    else:
      kwargs['CoNLLUDataset']['batch_size'] = 1000
  config = Config(defaults_file='', config_file=config_file, **kwargs)
  with open('debug.cfg', 'w') as f:
    config.write(f)
  network_class = config.get('DEFAULT', 'network_class')
  network_list = config.get(network_class, 'input_network_classes')
  input_networks, networks = resolve_network_dependencies(config, network_class, network_list, {})
  NetworkClass = getattr(parser, network_class)
  network = NetworkClass(input_networks=input_networks, config=config)
  network.parse(conllu_files, output_dir=output_dir, output_filename=output_filename, testing=testing, debug=debug,nornn=nornn, check_iter=check_iter, gen_tree=gen_tree, get_argmax=get_argmax)
  return
Exemple #3
0
def run(**kwargs):
    """"""

    # Get the special arguments
    save_dir = kwargs.pop('save_dir')
    save_metadir = kwargs.pop('save_metadir')
    conllu_files = kwargs.pop('conllu_files')
    output_dir = kwargs.pop('output_dir')
    output_filename = kwargs.pop('output_filename')

    # Get the cl-defined options
    kwargs = {
        key: value
        for key, value in six.iteritems(kwargs) if value is not None
    }
    for section, values in six.iteritems(kwargs):
        if section in section_names:
            values = [value.split('=', 1) for value in values]
            kwargs[section] = {opt: value for opt, value in values}
    if 'DEFAULT' not in kwargs:
        kwargs['DEFAULT'] = {}

    # Figure out the save_directory
    if save_metadir is not None:
        kwargs['DEFAULT']['save_metadir'] = save_metadir
    if save_dir is None:
        save_dir = Config(**kwargs).get('DEFAULT', 'save_dir')
    config_file = os.path.join(save_dir, 'config.cfg')
    kwargs['DEFAULT']['save_dir'] = save_dir

    config = Config(defaults_file='', config_file=config_file, **kwargs)
    with open('debug.cfg', 'w') as f:
        config.write(f)
    network_class = config.get('DEFAULT', 'network_class')
    network_list = config.get(network_class, 'input_network_classes')
    input_networks, networks = resolve_network_dependencies(
        config, network_class, network_list, {})
    NetworkClass = getattr(parser, network_class)
    network = NetworkClass(input_networks=input_networks, config=config)
    network.parse(conllu_files,
                  output_dir=output_dir,
                  output_filename=output_filename)
    return
Exemple #4
0
def resolve_network_dependencies(config, network_class, network_list, networks):
  if network_list in ('None', ''):
    return set(), networks
  else:
    network_list = network_list.split(':')
    if network_class not in networks:
      for _network_class in network_list:
        config_file = os.path.join(config.get('DEFAULT', _network_class + '_dir'), 'config.cfg')
        _config = Config(config_file=config_file)
        _network_list = _config.get(_network_class, 'input_network_classes')
        input_networks, networks = resolve_network_dependencies(_config, _network_class, _network_list, networks)
        NetworkClass = getattr(parser, _network_class)
        networks[_network_class] = NetworkClass(input_networks=input_networks, config=config)
    return set(networks[_network_class] for _network_class in network_list), networks
Exemple #5
0
def hpo(**kwargs):
  """"""
  
  # Get the special arguments
  noscreen = kwargs.pop('noscreen')
  save_dir = kwargs.pop('save_dir')
  save_metadir = kwargs.pop('save_metadir')
  network_class = kwargs.pop('network_class')
  config_file = kwargs.pop('config_file')
  rand_file = kwargs.pop('rand_file')
  eval_metric = kwargs.pop('eval_metric')
  
  # Get the cl-defined options
  kwargs = {key: value for key, value in six.iteritems(kwargs) if value is not None}
  for section, values in six.iteritems(kwargs):
    if section in section_names:
      values = [value.split('=', 1) for value in values]
      kwargs[section] = {opt: value for opt, value in values}
  if 'DEFAULT' not in kwargs:
    kwargs['DEFAULT'] = {}
  kwargs['DEFAULT']['network_class'] = network_class
    
  # Figure out the save_directory
  if save_metadir is not None:
    kwargs['DEFAULT']['save_metadir'] = save_metadir
  if save_dir is None:
    save_dir = Config(**kwargs).get('DEFAULT', 'save_dir')
  if not os.path.exists(save_dir):
    os.makedirs(save_dir)
  
  # Get the randomly generated options and possibly add them
  #-------------------------------------------------------------
  lang = kwargs['DEFAULT']['LANG']
  treebank = kwargs['DEFAULT']['TREEBANK']
  lc = kwargs['DEFAULT']['LC']
  tb = kwargs['DEFAULT']['TB']
  base = 'data/CoNLL18/UD_{}-{}/{}_{}-ud-dev.conllu'.format(lang, treebank, lc, tb)
  def eval_func(save_dir):
    return evaluate(base, os.path.join(save_dir, 'parsed', base), eval_metric)
  #-------------------------------------------------------------
  rargs = next(MVGHPO(rand_file, save_dir, eval_func=eval_func))
  for section in rargs:
    if section not in kwargs:
      kwargs[section] = rargs[section]
    else:
      for option, value in six.iteritems(rargs[section]):
        if option not in kwargs[section]:
          kwargs[section][option] = value
  save_dir = os.path.join(save_dir, str(int(time.time()*100000)))
  
  # If not loading, ask the user if they want to overwrite the directory
  if os.path.isdir(save_dir):
    print()
    sys.exit(0)
  else:
    os.mkdir(save_dir)
    os.system('git rev-parse HEAD >> {}'.format(os.path.join(save_dir, 'HEAD')))
  
  kwargs['DEFAULT']['save_dir'] = save_dir
  config = Config(config_file=config_file, **kwargs)
  network_list = config.get(network_class, 'input_network_classes')
  with open(os.path.join(save_dir, 'config.cfg'), 'w') as f:
    config.write(f)
  input_networks, networks = resolve_network_dependencies(config, network_class, network_list, {})
  NetworkClass = getattr(parser, network_class)
  network = NetworkClass(input_networks=input_networks, config=config)
  network.train(noscreen=noscreen)
  return