def train(config, level_name=None, timestamp=None, time_budget=None,
          verbose_logging=None, debug=None):
    """
    Trains a given YAML file.

    Parameters
    ----------
    config : str
        A YAML configuration file specifying the training procedure.
    level_name : bool, optional
        Display the log level (e.g. DEBUG, INFO) for each logged message.
    timestamp : bool, optional
        Display human-readable timestamps for each logged message.
    time_budget : int, optional
        Time budget in seconds. Stop training at the end of an epoch
        if more than this number of seconds has elapsed.
    verbose_logging : bool, optional
        Display timestamp, log level and source logger for every logged
        message (implies timestamp and level_name are True).
    debug : bool, optional
        Display any DEBUG-level log messages, False by default.
    """
    train_obj = serial.load_train_file(config)

    # The train file may yield a single Train object or an iterable of
    # them (multi-phase training); probe with iter() to find out which.
    try:
        iter(train_obj)
        iterable = True
    except TypeError:
        iterable = False

    # Undo our custom logging setup.
    restore_defaults()
    # Set up the root logger with a custom handler that logs stdout for INFO
    # and DEBUG and stderr for WARNING, ERROR, CRITICAL.
    root_logger = logging.getLogger()
    if verbose_logging:
        formatter = logging.Formatter(fmt="%(asctime)s %(name)s %(levelname)s "
                                          "%(message)s")
        handler = CustomStreamHandler(formatter=formatter)
    else:
        if timestamp:
            prefix = '%(asctime)s '
        else:
            prefix = ''
        # BUGFIX: level_name is documented above but was previously ignored
        # in this branch; honor it the same way the verbose path does.
        if level_name:
            prefix = prefix + '%(levelname)s '
        formatter = CustomFormatter(prefix=prefix, only_from='pylearn2')
        handler = CustomStreamHandler(formatter=formatter)
    root_logger.addHandler(handler)
    # Set the root logger level.
    if debug:
        root_logger.setLevel(logging.DEBUG)
    else:
        root_logger.setLevel(logging.INFO)

    if iterable:
        # enumerate() accepts any iterable directly; the extra iter() call
        # in the original was redundant.
        for number, subobj in enumerate(train_obj):
            # Publish a variable indicating the training phase.
            phase_variable = 'PYLEARN2_TRAIN_PHASE'
            phase_value = 'phase%d' % (number + 1)
            os.environ[phase_variable] = phase_value

            # Execute this training phase.
            subobj.main_loop(time_budget=time_budget)

            # Clean up, in case there's a lot of memory used that's
            # necessary for the next phase.
            del subobj
            gc.collect()
    else:
        train_obj.main_loop(time_budget=time_budget)
'training procedure') return parser if __name__ == "__main__": parser = make_argument_parser() args = parser.parse_args() train_obj = serial.load_train_file(args.config) try: iter(train_obj) iterable = True except TypeError as e: iterable = False # Undo our custom logging setup. restore_defaults() # Set up the root logger with a custom handler that logs stdout for INFO # and DEBUG and stderr for WARNING, ERROR, CRITICAL. root_logger = logging.getLogger() if args.verbose_logging: formatter = logging.Formatter(fmt="%(asctime)s %(name)s %(levelname)s " "%(message)s") handler = CustomStreamHandler(formatter=formatter) else: if args.timestamp: prefix = '%(asctime)s ' else: prefix = '' formatter = CustomFormatter(prefix=prefix, only_from='pylearn2') handler = CustomStreamHandler(formatter=formatter) root_logger.addHandler(handler)
def train(config, level_name=None, timestamp=None, time_budget=None,
          verbose_logging=None, debug=None, environ=None,
          skip_exceptions=False):
    """
    Trains a given YAML file.

    Parameters
    ----------
    config : str
        A YAML configuration file specifying the training procedure.
    level_name : bool, optional
        Display the log level (e.g. DEBUG, INFO) for each logged message.
    timestamp : bool, optional
        Display human-readable timestamps for each logged message.
    time_budget : int, optional
        Time budget in seconds. Stop training at the end of an epoch
        if more than this number of seconds has elapsed.
    verbose_logging : bool, optional
        Display timestamp, log level and source logger for every logged
        message (implies timestamp and level_name are True).
    debug : bool, optional
        Display any DEBUG-level log messages, False by default.
    environ : dict, optional
        Custom variables to be replaced in yaml file.
    skip_exceptions : bool, optional
        If True, an exception raised by one run is recorded in the report
        (when a report is active) and the remaining runs still execute;
        if False the exception propagates.
    """
    # Undo our custom logging setup.
    restore_defaults()
    # Set up the root logger with a custom handler that logs stdout for INFO
    # and DEBUG and stderr for WARNING, ERROR, CRITICAL.
    root_logger = logging.getLogger()
    if verbose_logging:
        formatter = logging.Formatter(fmt="%(asctime)s %(name)s %(levelname)s "
                                          "%(message)s")
        handler = CustomStreamHandler(formatter=formatter)
    else:
        if timestamp:
            prefix = '%(asctime)s '
        else:
            prefix = ''
        if level_name:
            prefix = prefix + '%(levelname)s '
        formatter = CustomFormatter(prefix=prefix, only_from='pylearn2')
        handler = CustomStreamHandler(formatter=formatter)
    root_logger.addHandler(handler)
    # Set the root logger level.
    if debug:
        root_logger.setLevel(logging.DEBUG)
    else:
        root_logger.setLevel(logging.INFO)

    # publish environment variables relevant to this file
    serial.prepare_train_file(config)

    # load the tree of Proxy objects
    yaml_tree = yaml_parse.load_path(config,
                                     instantiate=False,
                                     environ=environ)
    if not isinstance(yaml_tree, dict):
        raise ValueError('.yaml file is expected to contain a dictionary as the root object')
    # BUGFIX: dict.has_key() was removed in Python 3; use the `in` operator.
    elif not ('multiseq' in yaml_tree and 'train' in yaml_tree):
        raise ValueError('.yaml file is expected to have two keys: multiseq and train')

    # the two important objects
    multiseq = yaml_tree['multiseq']
    train_list = yaml_tree['train']
    if not isinstance(train_list, list):
        train_list = [train_list]

    # prepare to run
    multiseq.first_iteration()
    cont_flag = True

    # see if we're going to generate a report
    report_f = None
    report = _getVar('PYLEARN2_REPORT', multiseq.dynamic_env)
    if report:
        report_f = open(report, "a")
        hdr = (' Index '
               'Date & Time '
               'Tag '
               'Finish Train Total Batchs Epchs Exampls Params '
               'Best result Tests in file\n')
        hdr_guards = '-' * len(hdr) + '\n'
        report_f.write('\n')
        report_f.write(hdr_guards)
        report_f.write(hdr)
        report_f.write(hdr_guards)
        report_f.write('\n')

    emergency_exit = False
    while cont_flag:
        # log to user and start a new line for the report
        root_logger.debug('Run %s with tag %s',
                          multiseq.dynamic_env['MULTISEQ_ITER'],
                          multiseq.dynamic_env['MULTISEQ_TAG'])
        if report_f:
            report_f.write('%6s %19s %36s ' %
                           (multiseq.dynamic_env['MULTISEQ_ITER'],
                            strftime("%Y %m %d %H:%M:%S"),
                            multiseq.dynamic_env['MULTISEQ_TAG']))

        # update the environment with our dynamic variables
        # yaml_parse.additional_environ = multiseq.dynamic_env
        os.environ.update(multiseq.dynamic_env)

        # BUGFIX: initialize before the try block so the report/cleanup code
        # further down cannot raise NameError when instantiation itself
        # fails (previously possible with skip_exceptions=True).
        train_list_inst = None
        subobj_completed = 0

        # as this will probably be an unattended process we may want to
        # tolerate exceptions
        try:
            # TODO: we are accessing a protected member here
            # either change the name or define a wrapper
            train_list_inst = yaml_parse._instantiate(train_list)

            # if the environment defines a PARAMETERS_FILE variable
            # dump the parameters there as simple text
            multiseq.save_params()

            # perform all the tests/experiments once
            first_subobj = True
            for number, subobj in enumerate(train_list_inst):
                # Publish a variable indicating the training phase.
                phase_variable = 'PYLEARN2_TRAIN_PHASE'
                phase_value = 'phase%d' % (number + 1)
                os.environ[phase_variable] = phase_value

                # Execute this training phase.
                try:
                    subobj.main_loop(time_budget=time_budget)
                except BaseException:
                    # release resources before re-raising (was a bare
                    # `except:` — semantics unchanged, intent explicit)
                    subobj.tear_down()
                    raise

                # log first train to the report
                if first_subobj and report_f:
                    params = subobj.model.get_params()
                    # total count of scalar parameters across all tensors
                    par_cnt = sum(x.get_value().size for x in params)
                    report_f.write('%6s %8d %8d %6d %6d %8d %8d %s' %
                                   (str(subobj.model.monitor.training_succeeded),
                                    subobj.training_seconds.get_value().item(0),
                                    subobj.total_seconds.get_value().item(0),
                                    subobj.model.monitor.get_batches_seen(),
                                    subobj.model.monitor.get_epochs_seen(),
                                    subobj.model.monitor.get_examples_seen(),
                                    par_cnt,
                                    _getBestResult(subobj)))
                    # TODO: report performance here or channels
                    # for best model according to objective
                    first_subobj = False

                # Clean up, in case there's a lot of memory used that's
                # necessary for the next phase.
                # TODO: because subobj is part of a bigger object it may be
                # that it does not get cleaned up here.
                del subobj
                gc.collect()
                subobj_completed = subobj_completed + 1
        except (KeyboardInterrupt, SystemExit):
            emergency_exit = True
            if report_f:
                report_f.write('%50s' % 'User terminated')
        except Exception as exc:
            # BUGFIX: `except Exception, exc` is Python-2-only syntax;
            # the rest of the file already uses the `as` form.
            if skip_exceptions:
                if report_f:
                    report_f.write('%50s' % str(exc))
            else:
                raise

        # we've completed a run; finalize report line for it
        if report_f:
            report_f.write(' %4d completed\n' % subobj_completed)
            report_f.flush()

        # user requested exit
        if emergency_exit:
            break

        # LiveMonitoring seems to stay behind and binded to same port
        # That will throw an exception on next run
        # We try to mitigate that here.
        del train_list_inst
        gc.collect()

        # prepare next run
        cont_flag = multiseq.next_iteration()

    # BUGFIX: close the report file (previously leaked the handle).
    if report_f:
        report_f.close()