def __init__(self, model, graph=None):
    """Bind this agent to `model` and set up the graph it works on.

    Args:
        model: must be an instance of ``tfr.models.Model``.
        graph: forwarded to ``self._init_graph``; defaults to None.
    """
    assert isinstance(model, tfr.models.Model)
    self._model = model

    # One tensorflow graph and one tensorflow session per agent; both are
    # .. unset until `_init_graph` / a later launch fills them in
    self._session = self._graph = None

    # Graph variables (must exist before `_init_graph` assigns them)
    self._is_training = None
    self._init_graph(graph)

    # Tools for saving the model and writing summaries
    self._saver = self._summary_writer = None

    # Default note held by the agent
    self._note = Note()
class Agent(object):
    """An Agent works for TFrame Model, handling tensorflow stuffs.

    The agent keeps the ``tf.Graph`` and ``tf.Session`` a model runs on, the
    ``tf.train.Saver`` and ``tf.summary.FileWriter`` used to persist it, and
    a ``Note`` collecting records produced while the model is used.
    """

    def __init__(self, model, graph=None):
        """Bind the agent to `model` and initialize its graph.

        Args:
            model: must be an instance of ``tfr.models.Model``.
            graph: a ``tf.Graph`` to work on; a new graph is created when
                None.
        """
        # Each agent works on one tensorflow graph with a tensorflow session
        # .. set to it
        assert isinstance(model, tfr.models.Model)
        self._model = model
        self._session = None
        self._graph = None
        # Graph variables
        self._is_training = None
        self._init_graph(graph)
        # An agent saves model and writes summary
        self._saver = None
        self._summary_writer = None
        # An agent holds a default note
        self._note = Note()
        context.note = self._note

    # region : Properties

    # region : Accessors

    @property
    def graph(self):
        """The ``tf.Graph`` of this agent (asserted to be set)."""
        assert isinstance(self._graph, tf.Graph)
        return self._graph

    @property
    def session(self):
        """The ``tf.Session`` of this agent (asserted to be launched)."""
        assert isinstance(self._session, tf.Session)
        return self._session

    @property
    def saver(self):
        """The ``tf.train.Saver`` of this agent (asserted to be set)."""
        assert isinstance(self._saver, tf.train.Saver)
        return self._saver

    @property
    def summary_writer(self):
        """The ``tf.summary.FileWriter`` of this agent (asserted to be set)."""
        assert isinstance(self._summary_writer, tf.summary.FileWriter)
        return self._summary_writer

    # endregion : Accessors

    # region : Paths

    @property
    def root_path(self):
        """``hub.record_dir`` if ``hub.job_dir`` is './', else ``hub.job_dir``."""
        if hub.job_dir == './':
            return hub.record_dir
        return hub.job_dir

    @property
    def note_dir(self):
        """Note folder of the current model mark under `root_path`."""
        return check_path(self.root_path, hub.note_folder_name,
                          self._model.mark, create_path=hub.export_note)

    @property
    def log_dir(self):
        """Log folder of the current model mark under `root_path`."""
        return check_path(self.root_path, hub.log_folder_name,
                          self._model.mark, create_path=hub.summary)

    @property
    def ckpt_dir(self):
        """Checkpoint folder; ``hub.specified_ckpt_path`` wins when given."""
        if hub.specified_ckpt_path is not None:
            return hub.specified_ckpt_path
        return check_path(self.root_path, hub.ckpt_folder_name,
                          self._model.mark, create_path=hub.save_model)

    @property
    def snapshot_dir(self):
        """Snapshot folder of the current model mark under `root_path`."""
        return check_path(self.root_path, hub.snapshot_folder_name,
                          self._model.mark, create_path=hub.snapshot)

    @property
    def model_path(self):
        """This property will be used only when checkpoint is to be saved.

        Old name format: XXXX.model, where XXXX denotes
        self._model.model_name.
        New name format: <affix>.<lower-cased model_name>, to which
        `save_model` may append '(<rounds>_rounds)' and '-<suffix>'.
        """
        name = '{}.{}'.format(
            self._model.affix, self._model.model_name.lower())
        return os.path.join(self.ckpt_dir, name)

    @property
    def gather_path(self):
        """Path of the gather file (``hub.gather_file_name``) in root."""
        return os.path.join(check_path(self.root_path), hub.gather_file_name)

    @property
    def gather_summ_path(self):
        """Path of the summary file (``hub.gather_summ_name``) in root."""
        return os.path.join(check_path(self.root_path), hub.gather_summ_name)

    # endregion : Paths

    # endregion : Properties

    # region : Public Methods

    def get_status_feed_dict(self, is_training):
        """Return a feed dict mapping the is-training placeholder to a bool."""
        assert isinstance(is_training, bool)
        feed_dict = {self._is_training: is_training}
        return feed_dict

    def load(self):
        """Try to load a checkpoint from `ckpt_dir`.

        Returns:
            ``(False, 0, None)`` when ``hub.save_model`` is off while
            ``hub.overwrite`` is on; otherwise whatever `load_checkpoint`
            returns (`launch_model` unpacks it as flag, counter, rounds).
        """
        # TODO: when save_model option is turned off and the user want to
        #   try loading the exist model, set overwrite to False
        if not hub.save_model and hub.overwrite:
            return False, 0, None
        return load_checkpoint(self.ckpt_dir, self.session, self._saver)

    def save_model(self, rounds=None, suffix=None):
        """Save a checkpoint at `model_path`. rounds is used only by trainer.

        Args:
            rounds: if given, '(<rounds, 3 decimals>_rounds)' is appended to
                the checkpoint path.
            suffix: if given, '-<suffix>' is appended to the checkpoint path.
        """
        path = self.model_path
        if rounds is not None:
            path += '({:.3f}_rounds)'.format(rounds)
        if suffix is not None:
            path += '-{}'.format(suffix)
        save_checkpoint(path, self.session, self._saver, self._model.counter)

    @with_graph
    def reset_saver(self):
        """Create a fresh saver over ``self._model.variable_to_save``.

        Besides being called in `launch_model`, this method will be used in
        some very special cases, e.g. for saving train_stats used in dynamic
        evaluation (krause, 2018).
        """
        self._saver = tf.train.Saver(
            var_list=self._model.variable_to_save, max_to_keep=2)

    @with_graph
    def launch_model(self, overwrite=False):
        """Launch a session on the graph and initialize or restore the model.

        Args:
            overwrite: when True and ``hub.overwrite`` is also on, the
                enabled log/checkpoint/snapshot/note folders are cleared
                before the session is launched.

        Returns:
            The load flag returned by `load`.

        Raises:
            TypeError: if ``hub.visible_gpu_id`` is neither int nor str.
            AssertionError: if pruning iterations are requested but no
                checkpoint has been loaded.
        """
        if hub.suppress_logging:
            console.suppress_logging()
        # Before launch session, do some cleaning work
        if overwrite and hub.overwrite:
            paths = []
            if hub.summary:
                paths.append(self.log_dir)
            if hub.save_model:
                paths.append(self.ckpt_dir)
            if hub.snapshot:
                paths.append(self.snapshot_dir)
            if hub.export_note:
                paths.append(self.note_dir)
            clear_paths(paths)
        if hub.summary:
            self._check_bash()

        # Launch session on self.graph
        console.show_status('Launching session ...')
        config = tf.ConfigProto()
        if hub.visible_gpu_id is not None:
            gpu_id = hub.visible_gpu_id
            if isinstance(gpu_id, int):
                gpu_id = '{}'.format(gpu_id)
            elif not isinstance(gpu_id, str):
                raise TypeError(
                    '!! Visible GPU id provided must be an integer or a '
                    'string')
            os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
        # NOTE(review): nothing is set on `config` when hub.allow_growth is
        #   on -- confirm whether gpu_options.allow_growth should be enabled
        if not hub.allow_growth:
            value = hub.gpu_memory_fraction
            config.gpu_options.per_process_gpu_memory_fraction = value
        self._session = tf.Session(graph=self._graph, config=config)
        console.show_status('Session launched')

        # Prepare some tools
        self.reset_saver()
        if hub.summary or hub.hp_tuning:
            self._summary_writer = tf.summary.FileWriter(self.log_dir)

        # Initialize all variables
        self._session.run(tf.global_variables_initializer())
        # Set init_val for pruner if necessary
        # .. if existed model is loaded, variables will be overwritten
        if hub.prune_on:
            context.pruner.set_init_val_lottery18()
        # Try to load exist model
        load_flag, self._model.counter, self._model.rounds = self.load()
        # Sanity check
        if hub.prune_on and hub.pruning_iterations > 0:
            if not load_flag:
                raise AssertionError(
                    '!! Model {} should be initialized'.format(
                        self._model.mark))

        if not load_flag:
            assert self._model.counter == 0
            # Add graph
            if hub.summary:
                self._summary_writer.add_graph(self._session.graph)
            # Write model description to file
            if hub.snapshot:
                description_path = os.path.join(
                    self.snapshot_dir, 'description.txt')
                write_file(description_path, self._model.description)
            # Show status
            console.show_status('New model initiated')
        elif hub.branch_suffix not in [None, '']:
            hub.mark += hub.branch_suffix
            self._model.mark = hub.mark
            console.show_status(
                'Checkpoint switched to branch `{}`'.format(hub.mark))

        self._model.launched = True
        self.take_notes('Model launched')

        # Handle structure detail here
        self._model.handle_structure_detail()

        return load_flag

    def shutdown(self):
        """Close the summary writer (if one is in use) and the session."""
        if hub.summary or hub.hp_tuning:
            self._summary_writer.close()
        self.session.close()

    def write_summary(self, summary, step=None):
        """Write `summary` to the summary writer; no-op if hub.summary is off.

        Args:
            summary: passed to ``FileWriter.add_summary`` unchanged.
            step: global step for the summary; when None it is taken from
                the trainer's total rounds (times 1000) if
                ``hub.epoch_as_step`` is on and the rounds are known,
                otherwise from the model counter.
        """
        if not hub.summary:
            return
        if step is None:
            assert context.trainer is not None
            step = self._default_step()
        assert isinstance(self._summary_writer, tf.summary.FileWriter)
        self._summary_writer.add_summary(summary, step)

    def save_plot(self, fig, filename):
        """Save `fig` as `filename` inside the snapshot folder."""
        imtool.save_plt(fig, '{}/{}'.format(self.snapshot_dir, filename))

    # endregion : Public Methods

    # region : Public Methods for Note

    # region : For TensorViewer

    def take_down_scalars_and_tensors(self, scalars, tensors):
        """Record `scalars` and `tensors` (both dicts) in the note.

        The step is chosen the same way as the default step in
        `write_summary`.
        """
        assert isinstance(scalars, dict) and isinstance(tensors, dict)
        step = self._default_step()
        self._note.take_down_scalars_and_tensors(step, scalars, tensors)

    # endregion : For TensorViewer

    # region : For SummaryViewer

    def put_down_configs(self, th):
        """Record the key options of the ``Config`` instance `th`."""
        assert isinstance(th, Config)
        self._note.put_down_configs(th.key_options)

    def put_down_criterion(self, name, value):
        """Record a named criterion value in the note."""
        self._note.put_down_criterion(name, value)

    def gather_to_summary(self):
        """Append the note to the pickled list stored at `gather_summ_path`.

        When ``hub.gather_only_scalars`` is on, ``self._note.tensor_free`` is
        stored instead of the note itself.
        """
        import pickle
        # Try to load note list into summaries
        file_path = self.gather_summ_path
        if os.path.exists(file_path):
            # NOTE(review): pickle.load executes arbitrary code from the
            #   file; fine only while the summary file is produced locally
            with open(file_path, 'rb') as f:
                summary = pickle.load(f)
            assert len(summary) > 0
        else:
            summary = []
        # Add note to list and save
        if hub.gather_only_scalars:
            note = self._note.tensor_free
        else:
            note = self._note
        summary.append(note)
        with open(file_path, 'wb') as f:
            pickle.dump(summary, f, pickle.HIGHEST_PROTOCOL)
        # Show status
        console.show_status(
            'Note added to summaries ({} => {}) at `{}`'.format(
                len(summary) - 1, len(summary), file_path))

    # endregion : For SummaryViewer

    def take_notes(self, content, date_time=True, prompt=None):
        """Write one line to the note.

        Args:
            content: the text to write; must be a string.
            date_time: prefix the line with a time stamp. Ignored when
                `prompt` is a string.
            prompt: if a string, it is used as the prefix instead of the
                time stamp.

        Raises:
            TypeError: if `content` is not a string.
        """
        if not isinstance(content, str):
            raise TypeError('!! content must be a string')
        if isinstance(prompt, str):
            date_time = False
            content = '{} {}'.format(prompt, content)
        if date_time:
            content = '{} {}'.format(self._time_stamp(), content)
        self._note.write_line(content)

    def show_notes(self):
        """Print the note content to the console under a 'Notes' section."""
        console.section('Notes')
        console.write_line(self._note.content)

    def export_notes(self, filename='notes'):
        """Export the note as `<filename>.txt` and `<filename>.note`.

        The text is appended to the .txt file after a separator line; the
        .note file is written by ``self._note.save``. Requires
        ``hub.export_note`` to be on.
        """
        assert hub.export_note
        # Export .txt file
        file_path = '{}/{}.txt'.format(self.note_dir, filename)
        # Use a context manager so the handle is closed even if a write fails
        with open(file_path, 'a') as writer:
            writer.write('=' * 79 + '\n')
            writer.write(self._note.content + '\n')
        # Export .note file
        file_path = '{}/{}.note'.format(self.note_dir, filename)
        self._note.save(file_path)
        console.show_status('Note exported to `{}`'.format(file_path))

    def gather_notes(self, take_down_time=False):
        """Prepend the note content to the gather file, then gather summary.

        Args:
            take_down_time: prefix the gathered text with a time stamp.
        """
        assert hub.gather_note
        # If gather file does not exist, create one
        with open(self.gather_path, 'a'):
            pass
        # Gather notes to .txt file
        line = self._note.content
        with open(self.gather_path, 'r+') as f:
            # Read the old content, empty the file and write the new entry
            # .. in front of the old content
            content = f.readlines()
            f.seek(0)
            f.truncate()
            if take_down_time:
                # NOTE(review): the time stamp already carries brackets, so
                #   the entry starts with `[[...]]` -- confirm this is wanted
                line = '[{}] {}'.format(self._time_stamp(), line)
            f.write(line + '\n')
            f.write('-' * 79 + '\n')
            f.writelines(content)

        # TODO: find a way to update immediately after training is over
        # Gather notes to .summ file
        self.gather_to_summary()

    # endregion : Public Methods for Note

    # region : Private Methods

    @staticmethod
    def _time_stamp():
        """Return the bracketed time stamp used as a prefix for note lines.

        The year is the ``%Y`` string without its first two characters and
        the month is the first three characters of ``%B``.
        """
        return time.strftime('[{}-{}-%d %H:%M:%S]'.format(
            time.strftime('%Y')[2:], time.strftime('%B')[:3]))

    def _default_step(self):
        """Return the step used when the caller does not provide one."""
        if hub.epoch_as_step and context.trainer.total_rounds is not None:
            return int(context.trainer.total_rounds * 1000)
        return self._model.counter

    def _init_graph(self, graph):
        """Set `self._graph` and create the is-training placeholder in it."""
        if graph is not None:
            assert isinstance(graph, tf.Graph)
            self._graph = graph
        else:
            self._graph = tf.Graph()
        # Initialize graph variables
        with self.graph.as_default():
            self._is_training = tf.placeholder(
                dtype=tf.bool, name=pedia.is_training)
            tf.add_to_collection(pedia.is_training, self._is_training)

        # TODO
        # When linking batch-norm layer (and dropout layer),
        #   this placeholder will be got from default graph
        if not hub.suppress_current_graph:
            context.current_graph = self._graph

    def _check_bash(self):
        """Create tensorboard launcher scripts in the root path if missing."""
        command = 'tensorboard --logdir=./logs/ --port={}'.format(hub.tb_port)
        file_path = check_path(self.root_path, create_path=True)
        file_names = ['win_launch_tensorboard.bat',
                      'unix_launch_tensorboard.sh']
        for file_name in file_names:
            path = os.path.join(file_path, file_name)
            if not os.path.exists(path):
                # Context manager closes the handle even if the write fails
                with open(path, 'w') as f:
                    f.write(command)

    # endregion : Private Methods
class Agent(object):
    """An Agent works for TFrame Model, handling tensorflow stuffs.

    The agent keeps the ``tf.Graph`` and ``tf.Session`` a model runs on, the
    ``tf.train.Saver`` and ``tf.summary.FileWriter`` used to persist it, and
    a ``Note`` collecting text lines written while the model is used.
    """

    def __init__(self, model, graph=None):
        """Bind the agent to `model` and initialize its graph.

        Args:
            model: must be an instance of ``tfr.models.Model``.
            graph: a ``tf.Graph`` to work on; a new graph is created when
                None.
        """
        # Each agent works on one tensorflow graph with a tensorflow session
        # .. set to it
        assert isinstance(model, tfr.models.Model)
        self._model = model
        self._session = None
        self._graph = None
        # Graph variables
        self._is_training = None
        self._init_graph(graph)
        # An agent saves model and writes summary
        self._saver = None
        self._summary_writer = None
        # An agent holds a default note
        self._note = Note()

    # region : Properties

    # region : Accessors

    @property
    def graph(self):
        """The ``tf.Graph`` of this agent (asserted to be set)."""
        assert isinstance(self._graph, tf.Graph)
        return self._graph

    @property
    def session(self):
        """The ``tf.Session`` of this agent (asserted to be launched)."""
        assert isinstance(self._session, tf.Session)
        return self._session

    @property
    def saver(self):
        """The ``tf.train.Saver`` of this agent (asserted to be set)."""
        assert isinstance(self._saver, tf.train.Saver)
        return self._saver

    @property
    def summary_writer(self):
        """The ``tf.summary.FileWriter`` of this agent (asserted to be set)."""
        assert isinstance(self._summary_writer, tf.summary.FileWriter)
        return self._summary_writer

    # endregion : Accessors

    # region : Paths

    @property
    def note_dir(self):
        """Note folder of the current model mark under the record folder."""
        return check_path(hub.job_dir, hub.record_dir, hub.note_folder_name,
                          self._model.mark, create_path=hub.export_note)

    @property
    def log_dir(self):
        """Log folder of the current model mark under the record folder."""
        return check_path(hub.job_dir, hub.record_dir, hub.log_folder_name,
                          self._model.mark, create_path=hub.summary)

    @property
    def ckpt_dir(self):
        """Checkpoint folder of the current model mark."""
        return check_path(hub.job_dir, hub.record_dir, hub.ckpt_folder_name,
                          self._model.mark, create_path=hub.save_model)

    @property
    def snapshot_dir(self):
        """Snapshot folder of the current model mark."""
        return check_path(hub.job_dir, hub.record_dir,
                          hub.snapshot_folder_name, self._model.mark,
                          create_path=hub.snapshot)

    @property
    def model_path(self):
        """Checkpoint path: `<ckpt_dir>/<model_name>.model`."""
        return os.path.join(
            self.ckpt_dir, '{}.model'.format(self._model.model_name))

    # endregion : Paths

    # endregion : Properties

    # region : Public Methods

    def get_status_feed_dict(self, is_training):
        """Return a feed dict mapping the is-training placeholder to a bool."""
        assert isinstance(is_training, bool)
        feed_dict = {self._is_training: is_training}
        return feed_dict

    def load(self):
        """Try to load a checkpoint from `ckpt_dir`.

        Returns:
            ``(False, 0)`` when ``hub.save_model`` is off while
            ``hub.overwrite`` is on; otherwise whatever `load_checkpoint`
            returns (`launch_model` unpacks it as flag and counter).
        """
        # TODO: when save_model option is turned off and the user want to
        #   try loading the exist model, set overwrite to False
        if not hub.save_model and hub.overwrite:
            return False, 0
        return load_checkpoint(self.ckpt_dir, self.session, self._saver)

    def save_model(self):
        """Save a checkpoint at `model_path` with the model counter."""
        save_checkpoint(self.model_path, self.session, self._saver,
                        self._model.counter)

    @with_graph
    def launch_model(self, overwrite=False):
        """Launch a session on the graph and initialize or restore the model.

        Args:
            overwrite: when True and ``hub.overwrite`` is also on, the
                enabled log/checkpoint/snapshot/note folders are cleared
                before the session is launched.

        Returns:
            The load flag returned by `load`.
        """
        if hub.suppress_logging:
            console.suppress_logging()
        # Before launch session, do some cleaning work
        if overwrite and hub.overwrite:
            paths = []
            if hub.summary:
                paths.append(self.log_dir)
            if hub.save_model:
                paths.append(self.ckpt_dir)
            if hub.snapshot:
                paths.append(self.snapshot_dir)
            if hub.export_note:
                paths.append(self.note_dir)
            clear_paths(paths)
        if hub.summary:
            self._check_bash()

        # Launch session on self.graph
        console.show_status('Launching session ...')
        config = tf.ConfigProto()
        # NOTE(review): nothing is set on `config` when hub.allow_growth is
        #   on -- confirm whether gpu_options.allow_growth should be enabled
        if not hub.allow_growth:
            value = hub.gpu_memory_fraction
            config.gpu_options.per_process_gpu_memory_fraction = value
        self._session = tf.Session(graph=self._graph, config=config)
        console.show_status('Session launched')

        # Prepare some tools
        self._saver = tf.train.Saver(var_list=self._model.variable_to_save)
        if hub.summary or hub.hp_tuning:
            self._summary_writer = tf.summary.FileWriter(self.log_dir)

        # Initialize all variables
        self._session.run(tf.global_variables_initializer())
        # Try to load exist model
        load_flag, self._model.counter = self.load()
        if not load_flag:
            assert self._model.counter == 0
            # Add graph
            if hub.summary:
                self._summary_writer.add_graph(self._session.graph)
            # Write model description to file
            if hub.snapshot:
                description_path = os.path.join(
                    self.snapshot_dir, 'description.txt')
                write_file(description_path, self._model.description)
            # Show status
            console.show_status('New model initiated')

        self._model.launched = True
        self.take_notes('Model launched')

        return load_flag

    def shutdown(self):
        """Close the summary writer (if one is in use) and the session."""
        if hub.summary or hub.hp_tuning:
            self._summary_writer.close()
        self.session.close()

    def write_summary(self, summary, step=None):
        """Write `summary` to the summary writer; no-op if hub.summary is off.

        Args:
            summary: passed to ``FileWriter.add_summary`` unchanged.
            step: global step for the summary; when None it is taken from
                the trainer's total rounds (times 1000) if
                ``hub.epoch_as_step`` is on and the rounds are known,
                otherwise from the model counter.
        """
        if not hub.summary:
            return
        if step is None:
            if hub.epoch_as_step and tfr.trainer.total_rounds is not None:
                step = int(tfr.trainer.total_rounds * 1000)
            else:
                step = self._model.counter
        assert isinstance(self._summary_writer, tf.summary.FileWriter)
        self._summary_writer.add_summary(summary, step)

    def take_notes(self, content, date_time=True, prompt=None):
        """Write one line to the note.

        Args:
            content: the text to write; must be a string.
            date_time: prefix the line with a time stamp. Ignored when
                `prompt` is a string.
            prompt: if a string, it is used as the prefix instead of the
                time stamp.

        Raises:
            TypeError: if `content` is not a string.
        """
        if not isinstance(content, str):
            raise TypeError('!! content must be a string')
        if isinstance(prompt, str):
            date_time = False
            content = '{} {}'.format(prompt, content)
        if date_time:
            # Year is `%Y` without its first two characters; month is the
            # .. first three characters of `%B`
            time_str = time.strftime('[{}-{}-%d %H:%M:%S]'.format(
                time.strftime('%Y')[2:], time.strftime('%B')[:3]))
            content = '{} {}'.format(time_str, content)
        self._note.write_line(content)

    def export_notes(self, filename='notes'):
        """Append the note content to `<note_dir>/<filename>.txt`.

        A separator line of 79 '=' characters is written before the content.
        Requires ``hub.export_note`` to be on.
        """
        assert hub.export_note
        file_path = '{}/{}.txt'.format(self.note_dir, filename)
        # Use a context manager so the handle is closed even if a write fails
        with open(file_path, 'a') as writer:
            writer.write('=' * 79 + '\n')
            writer.write(self._note.content + '\n')
        console.show_status('Notes exported to {}'.format(file_path))

    def show_notes(self):
        """Print the note content to the console under a 'Notes' section."""
        console.section('Notes')
        console.write_line(self._note.content)

    def save_plot(self, fig, filename):
        """Save `fig` as `filename` inside the snapshot folder."""
        imtool.save_plt(fig, '{}/{}'.format(self.snapshot_dir, filename))

    # endregion : Public Methods

    # region : Private Methods

    def _init_graph(self, graph):
        """Set `self._graph` and create the is-training placeholder in it."""
        if graph is not None:
            assert isinstance(graph, tf.Graph)
            self._graph = graph
        else:
            self._graph = tf.Graph()
        # Initialize graph variables
        with self.graph.as_default():
            self._is_training = tf.placeholder(
                dtype=tf.bool, name=pedia.is_training)
            tf.add_to_collection(pedia.is_training, self._is_training)

        # TODO
        # When linking batch-norm layer (and dropout layer),
        #   this placeholder will be got from default graph
        tfr.current_graph = self._graph

    def _check_bash(self):
        """Create the windows tensorboard launcher script if it is missing."""
        file_path = check_path(hub.job_dir, hub.record_dir, create_path=True)
        file_path = os.path.join(file_path, 'win_launch_tensorboard.bat')
        if not os.path.exists(file_path):
            # Context manager closes the handle even if the write fails
            with open(file_path, 'w') as f:
                f.write('tensorboard --logdir=./logs/')

    # endregion : Private Methods