def _feed_hyper_params_to_tb(self, metric_dict=None):
    if "tensorboard" not in self.private_config["LOG_USED"]:
        logger.info("skip feeding hyper-param to tb")
        return
    for fmt in logger.Logger.CURRENT.output_formats:
        if isinstance(fmt, logger.TensorBoardOutputFormat):
            fmt.add_hyper_params_to_tb(self.hyper_param, metric_dict)
def configure(self, task_name, private_config_path, log_root=None,
              data_root=None, ignore_file_path=None, run_file=None):
    with open(private_config_path, encoding="UTF-8") as fs:
        try:
            # PyYAML >= 6 requires an explicit Loader and raises TypeError here.
            self.private_config = yaml.load(fs)
        except TypeError:
            self.private_config = yaml.safe_load(fs)
    self.run_file = run_file
    self.ignore_file_path = ignore_file_path
    self.task_name = task_name
    # log_root takes precedence over data_root when both are given.
    if log_root is not None:
        self.data_root = log_root
    else:
        self.data_root = data_root
    logger.info("private_config: ")
    self.dl_framework = self.private_config["DL_FRAMEWORK"]
    self.project_root = "/".join(private_config_path.split("/")[:-1])
    for k, v in self.private_config.items():
        logger.info("k: {}, v: {}".format(k, v))
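# Usage sketch (assumptions: RLA exposes a global singleton via
# `from RLA.easy_log.tester import tester`; the paths below are hypothetical):
from RLA.easy_log.tester import tester

tester.configure(task_name='train',
                 private_config_path='../rla_config.yaml',  # hypothetical path
                 log_root='../',
                 run_file='main.py')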
def new_saver(self, max_to_keep, var_prefix=None):
    """
    Initialize a new tf.train.Saver (TensorFlow) or reset the checkpoint
    bookkeeping (PyTorch).

    :param var_prefix: we use var_prefix to filter the variables for saving.
    :param max_to_keep: maximum number of recent checkpoints to keep.
    :return:
    """
    if self.dl_framework == FRAMEWORK.tensorflow:
        import tensorflow as tf
        if var_prefix is None:
            var_prefix = ''
        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
        logger.info("save variable :")
        for v in var_list:
            logger.info(v)
        self.saver = tf.train.Saver(var_list=var_list,
                                    max_to_keep=max_to_keep,
                                    filename=self.checkpoint_dir,
                                    save_relative_paths=True)
    elif self.dl_framework == FRAMEWORK.torch:
        self.max_to_keep = max_to_keep
        self.checkpoint_keep_list = []
    else:
        raise NotImplementedError
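# Usage sketch for the method-style saver (TF branch assumed): call after the
# graph is built; 'policy/' is a hypothetical variable-scope prefix.
tester.new_saver(max_to_keep=3, var_prefix='policy/')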
def log_files_gen(self):
    self.record_date = datetime.datetime.now()
    logger.info("gen log files for record date : {}".format(self.record_date))
    info = '&' + self.auto_parse_info()
    self.info = info
    code_dir, _ = self.__create_file_directory(
        osp.join(self.data_root, CODE, self.task_name), '', is_file=False)
    log_dir, _ = self.__create_file_directory(
        osp.join(self.data_root, LOG, self.task_name), '', is_file=False)
    self.pkl_dir, self.pkl_file = self.__create_file_directory(
        osp.join(self.data_root, ARCHIVE_TESTER, self.task_name), '.pkl')
    self.checkpoint_dir, _ = self.__create_file_directory(
        osp.join(self.data_root, CHECKPOINT, self.task_name), is_file=False)
    self.results_dir, _ = self.__create_file_directory(
        osp.join(self.data_root, OTHER_RESULTS, self.task_name), is_file=False)
    self.log_dir = log_dir
    self.code_dir = code_dir
    self._init_logger()
    self.serialize_object_and_save()
    self.__copy_source_code(self.run_file, code_dir)
    self._feed_hyper_params_to_tb()
    self.print_log_dir()
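# Usage sketch: after configure(), a single call builds the per-run directory
# tree (code/, log/, archive_tester/, checkpoint/, results/), backs up the
# source code, and starts the logger.
tester.log_files_gen()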
@classmethod
def log_file_finder(cls, record_date, task_name='train',
                    file_root='../checkpoint/', log_type='dir'):
    record_date = datetime.datetime.strptime(record_date,
                                             '%Y/%m/%d/%H-%M-%S-%f')
    prefix = osp.join(file_root, task_name)
    directory = str(record_date.strftime("%Y/%m/%d"))
    directory = osp.join(prefix, directory)
    file_found = ''
    for root, dirs, files in os.walk(directory):
        if log_type == 'dir':
            search_list = dirs
        elif log_type == 'files':
            search_list = files
        else:
            raise NotImplementedError
        for search_item in search_list:
            if search_item.startswith(str(record_date.strftime("%H-%M-%S-%f"))):
                # name layout: "<timestamp> <ip> <info...>"
                split_dir = search_item.split(' ')
                # self.__ipaddr = split_dir[1]
                info = " ".join(split_dir[2:])
                logger.info("load data: \n ts {}, \n ip {}, \n info {}".format(
                    split_dir[0], split_dir[1], info))
                file_found = search_item
                break
    return directory, file_found
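# Usage sketch: locate an archived run by its record date. The timestamp must
# match the '%Y/%m/%d/%H-%M-%S-%f' format above; the values are hypothetical.
res_dir, res_file = Tester.log_file_finder('2022/03/01/14-05-30-123456',
                                           task_name='train',
                                           file_root='../archive_tester/',
                                           log_type='files')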
def time_record_end(name):
    end_time = time.time()
    start_time = rc_start_time[name]
    logger.record_tabular("time_used/{}".format(name), end_time - start_time)
    logger.info("[test] func {0} time used {1:.2f}".format(
        name, end_time - start_time))
    del rc_start_time[name]
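# Usage sketch: time_record_end expects the start time to already be stored in
# rc_start_time[name] (the repo presumably provides a matching starter; the
# explicit assignment here is a hypothetical stand-in):
rc_start_time['rollout'] = time.time()
# ... timed code ...
time_record_end('rollout')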
def load_checkpoint(saver, tester):
    # TODO: load with variable scope.
    import tensorflow as tf
    logger.info("load checkpoint {}".format(tester.checkpoint_dir))
    ckpt_path = tf.train.latest_checkpoint(tester.checkpoint_dir)
    saver.restore(tf.get_default_session(), ckpt_path)
    max_iter = ckpt_path.split('-')[-1]
    tester.time_step_holder.set_time(max_iter)
    return int(max_iter)
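# Usage sketch: build a saver over all global variables, then restore the
# latest checkpoint tracked by the global tester (TF-1.x session assumed).
saver = new_saver(max_to_keep=5, var_prefix='')
start_iter = load_checkpoint(saver, tester)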
def time_used_wrap(name, func, *args, **kwargs):
    start_time = time.time()
    output = func(*args, **kwargs)
    end_time = time.time()
    time_used = end_time - start_time
    logger.info("[test] func {0} time used {1:.2f}".format(name, time_used))
    logger.record_tabular("time_used/{}".format(name), time_used)
    logger.dump_tabular()
    return output
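# Usage sketch: wrap any callable to log its wall-clock cost under
# time_used/<name>; evaluate_policy and its arguments are hypothetical.
result = time_used_wrap('evaluation', evaluate_policy, policy, env)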
def __init__(self, sftp_server, username, password, ignore=None):
    self.sftp_server = sftp_server
    self.username = username
    self.password = password
    self.sftp = self.sftpconnect()
    logger.info("login success.")
    self.ignore = ignore
    self.ignore_rules = []
    if self.ignore is not None:
        self.__init_gitignore()
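# Usage sketch (the class name, host, and credentials are all hypothetical):
# open an SFTP session that applies .gitignore-style rules when syncing files.
syncer = SFTPHandler(sftp_server='log.example.com', username='rla',
                     password='secret', ignore='.gitignore')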
def _init_logger(self):
    self.writer = None
    # logger configure
    logger.info("store file %s" % self.pkl_file)
    logger.configure(self.log_dir, self.private_config["LOG_USED"],
                     framework=self.private_config["DL_FRAMEWORK"])
    for fmt in logger.Logger.CURRENT.output_formats:
        if isinstance(fmt, logger.TensorBoardOutputFormat):
            self.writer = fmt.writer
    if "tensorboard" not in self.private_config["LOG_USED"]:
        time_step_holder.config(0, 0, tf_log=False)
@classmethod
def load_tester(cls, record_date, task_name, log_root):
    logger.info("load tester")
    res_dir, res_file = cls.log_file_finder(
        record_date, task_name=task_name,
        file_root=osp.join(log_root, ARCHIVE_TESTER), log_type='files')
    import dill
    with open(osp.join(res_dir, res_file), 'rb') as f:
        load_tester = dill.load(f)
    assert isinstance(load_tester, Tester)
    logger.info("update log files' root")
    load_tester.update_log_files_location(root=log_root)
    return load_tester
def load_checkpoint(self):
    if self.dl_framework == FRAMEWORK.tensorflow:
        # TODO: load with variable scope.
        import tensorflow as tf
        cpt_name = osp.join(self.checkpoint_dir)
        logger.info("load checkpoint {}".format(cpt_name))
        ckpt_path = tf.train.latest_checkpoint(cpt_name)
        self.saver.restore(tf.get_default_session(), ckpt_path)
        max_iter = ckpt_path.split('-')[-1]
        self.time_step_holder.set_time(max_iter)
        return int(max_iter), None
    elif self.dl_framework == FRAMEWORK.torch:
        import torch
        max_iter = self.checkpoint_keep_list[-1]
        return max_iter, torch.load(
            osp.join(self.checkpoint_dir,
                     "checkpoint-{}.pt".format(max_iter)))
    else:
        raise NotImplementedError
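# Usage sketch: resume from the most recent checkpoint. In the torch branch
# the returned dict is applied by the caller (key names are hypothetical).
start_iter, state = tester.load_checkpoint()
if state is not None:
    model.load_state_dict(state['model'])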
def add_summary_to_logger(self, summary, name='', simple_val=False, freq=20):
    """
    [deprecated] see RLA.logger.log_from_tf_summary
    """
    logger.warn("add_summary_to_logger is deprecated. "
                "See RLA.logger.log_from_tf_summary.")
    if "tensorboard" not in self.private_config["LOG_USED"]:
        logger.info("skip adding summary to tb")
        return
    if name not in self.summary_add_dict:
        self.summary_add_dict[name] = []
    if freq > 0:
        summary_ts = int(self.time_step_holder.get_time() / freq)
    else:
        summary_ts = 0
    if freq <= 0 or summary_ts not in self.summary_add_dict[name]:
        from tensorflow.core.framework import summary_pb2
        summ = summary_pb2.Summary()
        summ.ParseFromString(summary)
        if simple_val:
            list_field = summ.ListFields()

            def recursion_util(inp_field):
                # recurse into repeated fields; record scalar leaves
                if hasattr(inp_field, "__getitem__"):
                    for inp in inp_field:
                        recursion_util(inp)
                elif hasattr(inp_field, 'simple_value'):
                    logger.record_tabular(name + '/' + inp_field.tag,
                                          inp_field.simple_value)
                else:
                    pass

            recursion_util(list_field)
            logger.dump_tabular()
        else:
            self.writer.add_summary(summary, self.time_step_holder.get_time())
            self.writer.flush()
        self.summary_add_dict[name].append(summary_ts)
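# Migration sketch: the deprecation notice points to
# RLA.logger.log_from_tf_summary; assuming it accepts the same serialized
# tf.Summary bytes, the call becomes:
#   logger.log_from_tf_summary(summary)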
def download_file(self, remote_file, local_file):
    self.sftp = self.sftpconnect()
    logger.info("try download {}".format(local_file))
    if not os.path.isfile(local_file):
        logger.info("new file {}".format(local_file))
        self.sftp.get(remote_file, local_file)
    elif self.sftp.stat(remote_file).st_size != os.path.getsize(local_file):
        logger.info("update file {}".format(local_file))
        self.sftp.get(remote_file, local_file)
    else:
        logger.info("skip download file {}".format(remote_file))
def new_saver(max_to_keep, var_prefix, checkpoint_path=None):
    """
    Initialize a new tf.train.Saver.

    :param var_prefix: we use var_prefix to filter the variables for saving.
    :param max_to_keep: maximum number of recent checkpoints to keep.
    :param checkpoint_path: defaults to the global tester's checkpoint_dir.
    :return: the constructed tf.train.Saver.
    """
    import tensorflow as tf
    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, var_prefix)
    logger.info("save variable :")
    for v in var_list:
        logger.info(v)
    if checkpoint_path is None:
        # fall back to the checkpoint directory of the global tester
        checkpoint_path = tester.checkpoint_dir
    saver = tf.train.Saver(var_list=var_list, max_to_keep=max_to_keep,
                           filename=checkpoint_path, save_relative_paths=True)
    return saver
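# Usage sketch for the function-style API: an empty prefix selects every
# global variable; the path defaults to the global tester's checkpoint_dir.
saver = new_saver(max_to_keep=5, var_prefix='')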
def download_file(self, remote_file, local_file):
    bufsize = 1024
    logger.info("try download {}".format(local_file))
    if not os.path.isfile(local_file):
        logger.info("new file {}".format(local_file))
        with open(local_file, 'wb') as fp:
            self.ftp.retrbinary('RETR ' + remote_file, fp.write, bufsize)
    elif self.ftp.size(remote_file) != os.path.getsize(local_file):
        logger.info("update file {}".format(local_file))
        with open(local_file, 'wb') as fp:
            self.ftp.retrbinary('RETR ' + remote_file, fp.write, bufsize)
    else:
        logger.info("skip download file {}".format(remote_file))
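# Usage sketch (hypothetical instance and paths): fetch a log file only when
# it is missing locally or its size differs from the remote copy.
handler.download_file('/remote/log/progress.csv', './log/progress.csv')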
def print_large_memory_variable(self):
    import sys
    large_memory_dict = {}

    def sizeof_fmt(num, suffix='B'):
        # human-readable size plus the unit prefix it was rendered with
        for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
            if abs(num) < 1024.0:
                return "%3.1f %s%s" % (num, unit, suffix), unit
            num /= 1024.0
        return "%.1f %s%s" % (num, 'Yi', suffix), 'Yi'

    for name, size in sorted(((name, sys.getsizeof(value))
                              for name, value in locals().items()),
                             key=lambda x: -x[1])[:10]:
        size_str, fmt_type = sizeof_fmt(size)
        if fmt_type in ['', 'Ki', 'Mi']:
            # skip everything below the GiB range
            continue
        logger.info("{:>30}: {:>8}".format(name, size_str))
        large_memory_dict[str(name)] = size_str
    if large_memory_dict != {}:
        summary = self.dict_to_table_text_summary(large_memory_dict,
                                                  'large_memory')
        self.add_summary_to_logger(summary, 'large_memory')
def save_checkpoint(self, model_dict: Optional[dict] = None,
                    related_variable: Optional[dict] = None):
    if self.dl_framework == FRAMEWORK.tensorflow:
        import tensorflow as tf
        iter = self.time_step_holder.get_time()
        cpt_name = osp.join(self.checkpoint_dir, 'checkpoint')
        logger.info("save checkpoint to {}, iter {}".format(cpt_name, iter))
        self.saver.save(tf.get_default_session(), cpt_name, global_step=iter)
    elif self.dl_framework == FRAMEWORK.torch:
        import torch
        iter = self.time_step_holder.get_time()
        torch.save(model_dict,
                   f=osp.join(self.checkpoint_dir,
                              "checkpoint-{}.pt".format(iter)))
        self.checkpoint_keep_list.append(iter)
        if len(self.checkpoint_keep_list) > self.max_to_keep:
            # drop the oldest checkpoints beyond max_to_keep
            for i in range(len(self.checkpoint_keep_list) - self.max_to_keep):
                rm_ckp_name = osp.join(
                    self.checkpoint_dir,
                    "checkpoint-{}.pt".format(self.checkpoint_keep_list[i]))
                logger.info("rm the older checkpoint {}".format(rm_ckp_name))
                os.remove(rm_ckp_name)
            self.checkpoint_keep_list = self.checkpoint_keep_list[-self.max_to_keep:]
    else:
        raise NotImplementedError
    if related_variable is not None:
        for k, v in related_variable.items():
            self.add_custom_data(k, v, type(v), mode='replace')
    self.add_custom_data(DEFAULT_X_NAME, time_step_holder.get_time(), int,
                         mode='replace')
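# Usage sketch for the torch branch (key names and variables are hypothetical):
tester.save_checkpoint(model_dict={'model': model.state_dict(),
                                   'optim': optimizer.state_dict()},
                       related_variable={'epoch': epoch})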
def print_log_dir(self):
    logger.info("log dir: {}".format(self.log_dir))
    logger.info("pkl_file: {}".format(self.pkl_file))
    logger.info("checkpoint_dir: {}".format(self.checkpoint_dir))
    logger.info("results_dir: {}".format(self.results_dir))