def train(**kwargs):
    """Train ``network_class``, preparing its save directory first.

    Keyword arguments consumed here (not treated as configuration):

    load          -- truthy to keep an existing save directory (no prompt, no
                     deletion); also forwarded to ``network.train``.
    force         -- truthy to delete an existing save directory without asking.
    noscreen      -- forwarded to ``network.train``; when the user is prompted
                     and agrees to deletion, a truthy value exits instead.
    save_dir      -- stored as DEFAULT/save_dir when not None.
    save_metadir  -- stored as DEFAULT/save_metadir when not None.
    network_class -- name of the class looked up on ``parser``; also stored as
                     DEFAULT/network_class.
    config_file   -- configuration file handed to ``Config``.

    All remaining keywords whose value is not None are forwarded to ``Config``.
    Keys listed in ``section_names`` are expected to hold ``'option=value'``
    strings and are converted to ``{option: value}`` dicts first.

    Exits the process with status 0 if the user declines to overwrite an
    existing save directory.
    """
    import subprocess  # local import: only used to record the git commit

    # Special arguments
    load = kwargs.pop('load')
    force = kwargs.pop('force')
    noscreen = kwargs.pop('noscreen')
    save_dir = kwargs.pop('save_dir')
    save_metadir = kwargs.pop('save_metadir')
    network_class = kwargs.pop('network_class')
    config_file = kwargs.pop('config_file')

    # Command-line configuration overrides
    kwargs = {key: value for key, value in six.iteritems(kwargs) if value is not None}
    for section in list(kwargs):
        if section in section_names:
            pairs = [item.split('=', 1) for item in kwargs[section]]
            kwargs[section] = {opt: val for opt, val in pairs}
    kwargs.setdefault('DEFAULT', {})
    kwargs['DEFAULT']['network_class'] = network_class

    # Resolve the save directory (Config supplies it when not given here)
    if save_metadir is not None:
        kwargs['DEFAULT']['save_metadir'] = save_metadir
    if save_dir is not None:
        kwargs['DEFAULT']['save_dir'] = save_dir
    config = Config(config_file=config_file, **kwargs)
    save_dir = config.get('DEFAULT', 'save_dir')

    # Unless resuming, an existing save directory is wiped -- after asking the
    # user, unless `force` is set.
    if not load and os.path.isdir(save_dir):
        if not force:
            prompt = ('{} already exists. It will be deleted if you continue. '
                      'Do you want to proceed? [Y/n] ').format(save_dir)
            # Only an explicit y/yes/n/no is accepted; despite the "[Y/n]"
            # hint, empty input simply re-prompts.
            answer = ''
            while answer not in ('y', 'n', 'yes', 'no'):
                answer = input(prompt).lower()
            if answer in ('n', 'no'):
                print()
                sys.exit(0)
            elif noscreen:
                # NOTE(review): with noscreen the run stops here even after the
                # user agreed to delete the directory -- confirm this is intended.
                sys.exit(0)
        shutil.rmtree(save_dir)

    # NOTE(review): when an existing directory is kept (load=True) its saved
    # config.cfg is not re-read; the `config` built above is what gets used.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
        # Append the current git commit to <save_dir>/HEAD (best effort). git is
        # run without a shell so paths with spaces/metacharacters are safe.
        with open(os.path.join(save_dir, 'HEAD'), 'a') as head_file:
            try:
                subprocess.call(['git', 'rev-parse', 'HEAD'], stdout=head_file)
            except OSError:
                pass  # git is not installed; the commit hash is optional metadata

    network_list = config.get(network_class, 'input_network_classes')
    if not load:
        with open(os.path.join(save_dir, 'config.cfg'), 'w') as cfg_file:
            config.write(cfg_file)

    input_networks, _ = resolve_network_dependencies(config, network_class, network_list, {})
    NetworkClass = getattr(parser, network_class)
    network = NetworkClass(input_networks=input_networks, config=config)
    network.train(load=load, noscreen=noscreen)
def run(**kwargs):
    """Parse CoNLL-U files with a trained network (extended variant).

    NOTE: another ``def run`` appears later in this file and rebinds the name
    at import time, so this variant is only reachable if that later definition
    is removed or renamed.

    Keyword arguments consumed here (not forwarded to ``Config``):

    save_dir        -- model directory; when None, the DEFAULT/save_dir value
                       that ``Config`` yields is used. Its ``config.cfg`` is
                       loaded.
    save_metadir    -- stored as DEFAULT/save_metadir when not None.
    conllu_files, output_dir, output_filename,
    testing, debug, nornn, check_iter, gen_tree, get_argmax
                    -- passed through to ``network.parse``. ``testing`` (and
                       ``debug`` with it) also force CoNLLUDataset settings.

    Every other keyword whose value is not None is a configuration override;
    keys listed in ``section_names`` carry ``'option=value'`` strings that are
    turned into ``{option: value}`` dicts. The merged configuration is also
    written to ``debug.cfg`` in the current working directory.
    """
    # Special arguments
    save_dir = kwargs.pop('save_dir')
    save_metadir = kwargs.pop('save_metadir')
    conllu_files = kwargs.pop('conllu_files')
    output_dir = kwargs.pop('output_dir')
    output_filename = kwargs.pop('output_filename')
    testing = kwargs.pop('testing')
    debug = kwargs.pop('debug')
    nornn = kwargs.pop('nornn')
    check_iter = kwargs.pop('check_iter')
    gen_tree = kwargs.pop('gen_tree')
    get_argmax = kwargs.pop('get_argmax')

    # Command-line configuration overrides
    kwargs = {key: value for key, value in six.iteritems(kwargs) if value is not None}
    for section in list(kwargs):
        if section in section_names:
            pairs = [item.split('=', 1) for item in kwargs[section]]
            kwargs[section] = {opt: val for opt, val in pairs}
    kwargs.setdefault('DEFAULT', {})

    # Resolve the save directory and the config file stored in it
    if save_metadir is not None:
        kwargs['DEFAULT']['save_metadir'] = save_metadir
    if save_dir is None:
        save_dir = Config(**kwargs).get('DEFAULT', 'save_dir')
    config_file = os.path.join(save_dir, 'config.cfg')
    kwargs['DEFAULT']['save_dir'] = save_dir

    if testing:
        # Force the testing batch settings, but merge them into (rather than
        # replace) any CoNLLUDataset options given on the command line.
        dataset_opts = kwargs.setdefault('CoNLLUDataset', {})
        dataset_opts['batch_size'] = 1000
        if debug:
            dataset_opts['max_buckets'] = 30

    config = Config(defaults_file='', config_file=config_file, **kwargs)
    with open('debug.cfg', 'w') as cfg_dump:
        config.write(cfg_dump)

    network_class = config.get('DEFAULT', 'network_class')
    network_list = config.get(network_class, 'input_network_classes')
    input_networks, _ = resolve_network_dependencies(config, network_class, network_list, {})
    NetworkClass = getattr(parser, network_class)
    network = NetworkClass(input_networks=input_networks, config=config)
    network.parse(conllu_files, output_dir=output_dir,
                  output_filename=output_filename, testing=testing, debug=debug,
                  nornn=nornn, check_iter=check_iter, gen_tree=gen_tree,
                  get_argmax=get_argmax)
def run(**kwargs):
    """Parse CoNLL-U files with a previously trained network.

    NOTE: this definition comes after an earlier, extended ``run`` in this
    file and rebinds the name, so (unless rebound again further down) this is
    the version callers get.

    Keyword arguments consumed here:

    save_dir        -- model directory; when None, the DEFAULT/save_dir value
                       that ``Config`` yields is used. Its ``config.cfg`` is
                       loaded.
    save_metadir    -- stored as DEFAULT/save_metadir when not None.
    conllu_files, output_dir, output_filename -- passed to ``network.parse``.

    Every other keyword whose value is not None is a configuration override;
    keys listed in ``section_names`` carry ``'option=value'`` strings that are
    turned into ``{option: value}`` dicts. The merged configuration is also
    dumped to ``debug.cfg`` in the current working directory.
    """
    (model_dir, meta_dir, input_files,
     out_dir, out_name) = (
        kwargs.pop(name) for name in
        ('save_dir', 'save_metadir', 'conllu_files', 'output_dir', 'output_filename'))

    # Keep only options actually given; expand 'opt=value' lists per section.
    overrides = {}
    for key, value in six.iteritems(kwargs):
        if value is None:
            continue
        if key in section_names:
            pairs = [item.split('=', 1) for item in value]
            value = {opt: val for opt, val in pairs}
        overrides[key] = value

    defaults = overrides.setdefault('DEFAULT', {})
    if meta_dir is not None:
        defaults['save_metadir'] = meta_dir
    if model_dir is None:
        model_dir = Config(**overrides).get('DEFAULT', 'save_dir')
    defaults['save_dir'] = model_dir

    config = Config(defaults_file='',
                    config_file=os.path.join(model_dir, 'config.cfg'),
                    **overrides)
    with open('debug.cfg', 'w') as dump:
        config.write(dump)

    cls_name = config.get('DEFAULT', 'network_class')
    inputs, _ = resolve_network_dependencies(
        config, cls_name, config.get(cls_name, 'input_network_classes'), {})
    network = getattr(parser, cls_name)(input_networks=inputs, config=config)
    network.parse(input_files, output_dir=out_dir, output_filename=out_name)
def resolve_network_dependencies(config, network_class, network_list, networks):
    """Instantiate (recursively) the input networks a network depends on.

    Args:
      config: configuration of the consuming network. For every dependency
        ``Dep`` its option DEFAULT/``<Dep>_dir`` names the directory holding
        that dependency's ``config.cfg``.
      network_class: name of the consuming network (dependencies are looked up
        and cached by their own names, so this is informational).
      network_list: colon-separated dependency class names, or ``'None'`` /
        ``''`` for no dependencies.
      networks: dict mapping class name -> already-built network. It is
        updated in place with every dependency built here.

    Returns:
      ``(inputs, networks)`` -- the set of network objects named in
      ``network_list``, and the (same) ``networks`` dict.
    """
    if network_list in ('None', ''):
        return set(), networks

    dependencies = network_list.split(':')
    for dep_class in dependencies:
        # Reuse a dependency built earlier, so a network shared by several
        # consumers is only instantiated once.
        if dep_class in networks:
            continue
        dep_config_file = os.path.join(config.get('DEFAULT', dep_class + '_dir'), 'config.cfg')
        dep_config = Config(config_file=dep_config_file)
        dep_inputs, networks = resolve_network_dependencies(
            dep_config, dep_class,
            dep_config.get(dep_class, 'input_network_classes'), networks)
        # NOTE(review): the dependency is constructed with the caller's
        # `config`, not `dep_config` -- confirm this is intended.
        NetworkClass = getattr(parser, dep_class)
        networks[dep_class] = NetworkClass(input_networks=dep_inputs, config=config)
    return {networks[name] for name in dependencies}, networks
def hpo(**kwargs):
    """Run a single hyperparameter-optimization trial.

    Keyword arguments consumed here (not treated as configuration):

    noscreen      -- forwarded to ``network.train``.
    save_dir      -- parent directory for trials; when None, the
                     DEFAULT/save_dir value ``Config`` yields is used. Created
                     if missing.
    save_metadir  -- stored as DEFAULT/save_metadir when not None.
    network_class -- name of the class looked up on ``parser``; also stored as
                     DEFAULT/network_class.
    config_file   -- configuration file handed to ``Config``.
    rand_file     -- passed to ``MVGHPO``.
    eval_metric   -- passed to ``evaluate`` when scoring trials.

    The remaining non-None keywords are configuration overrides (sections in
    ``section_names`` given as ``'option=value'`` strings). The DEFAULT
    overrides must contain LANG, TREEBANK, LC and TB (KeyError otherwise);
    they select the dev file used to score trials. Options proposed by
    ``MVGHPO`` only fill in what the command line left unset. The trial trains
    in a new timestamp-named subdirectory of ``save_dir``.
    """
    import subprocess  # local import: only used to record the git commit

    # Special arguments
    noscreen = kwargs.pop('noscreen')
    save_dir = kwargs.pop('save_dir')
    save_metadir = kwargs.pop('save_metadir')
    network_class = kwargs.pop('network_class')
    config_file = kwargs.pop('config_file')
    rand_file = kwargs.pop('rand_file')
    eval_metric = kwargs.pop('eval_metric')

    # Command-line configuration overrides
    kwargs = {key: value for key, value in six.iteritems(kwargs) if value is not None}
    for section in list(kwargs):
        if section in section_names:
            pairs = [item.split('=', 1) for item in kwargs[section]]
            kwargs[section] = {opt: val for opt, val in pairs}
    kwargs.setdefault('DEFAULT', {})
    kwargs['DEFAULT']['network_class'] = network_class

    # Resolve the parent directory that holds all trials
    if save_metadir is not None:
        kwargs['DEFAULT']['save_metadir'] = save_metadir
    if save_dir is None:
        save_dir = Config(**kwargs).get('DEFAULT', 'save_dir')
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Dev-set file used to score trials
    defaults = kwargs['DEFAULT']
    base = 'data/CoNLL18/UD_{}-{}/{}_{}-ud-dev.conllu'.format(
        defaults['LANG'], defaults['TREEBANK'], defaults['LC'], defaults['TB'])

    def eval_func(trial_dir):
        return evaluate(base, os.path.join(trial_dir, 'parsed', base), eval_metric)

    # Sample the next set of options; command-line values take precedence.
    rargs = next(MVGHPO(rand_file, save_dir, eval_func=eval_func))
    for section in rargs:
        if section not in kwargs:
            kwargs[section] = rargs[section]
        else:
            for option, value in six.iteritems(rargs[section]):
                if option not in kwargs[section]:
                    kwargs[section][option] = value

    # Each trial gets a fresh timestamp-named subdirectory; if that name is
    # somehow already taken, the trial is abandoned.
    save_dir = os.path.join(save_dir, str(int(time.time() * 100000)))
    if os.path.isdir(save_dir):
        print()
        sys.exit(0)
    os.mkdir(save_dir)
    # Append the current git commit to <save_dir>/HEAD (best effort). git is
    # run without a shell so paths with spaces/metacharacters are safe.
    with open(os.path.join(save_dir, 'HEAD'), 'a') as head_file:
        try:
            subprocess.call(['git', 'rev-parse', 'HEAD'], stdout=head_file)
        except OSError:
            pass  # git is not installed; the commit hash is optional metadata

    kwargs['DEFAULT']['save_dir'] = save_dir
    config = Config(config_file=config_file, **kwargs)
    network_list = config.get(network_class, 'input_network_classes')
    with open(os.path.join(save_dir, 'config.cfg'), 'w') as cfg_file:
        config.write(cfg_file)

    input_networks, _ = resolve_network_dependencies(config, network_class, network_list, {})
    NetworkClass = getattr(parser, network_class)
    network = NetworkClass(input_networks=input_networks, config=config)
    network.train(noscreen=noscreen)