# Example #2
class Agent(object):
  """Handle the TensorFlow-side plumbing for one TFrame model.

  An agent owns one `tf.Graph` and the `tf.Session` launched on it, creates
  the saver and summary writer, resolves the folders used for logs,
  checkpoints, snapshots and notes, and holds the default `Note` of the
  model.
  """

  def __init__(self, model, graph=None):
    """Bind the agent to `model` and prepare its graph.

    Args:
      model: the `tfr.models.Model` this agent works for.
      graph: an existing `tf.Graph` to work on; a new one is created
        when omitted.
    """
    # Each agent works on one tensorflow graph with a tensorflow session
    # .. set to it
    assert isinstance(model, tfr.models.Model)
    self._model = model
    self._session = None
    self._graph = None
    # Graph variables
    self._is_training = None
    # Without this call `graph` would be ignored and both `_graph` and
    # `_is_training` would stay None
    self._init_graph(graph)
    # An agent saves model and writes summary
    self._saver = None
    self._summary_writer = None
    # An agent holds a default note
    self._note = Note()
    context.note = self._note

  # region : Properties

  # region : Accessors

  @property
  def graph(self):
    """The `tf.Graph` this agent works on."""
    assert isinstance(self._graph, tf.Graph)
    return self._graph

  @property
  def session(self):
    """The launched `tf.Session`; asserts that one has been launched."""
    assert isinstance(self._session, tf.Session)
    return self._session

  @property
  def saver(self):
    """The `tf.train.Saver` used for checkpoints; asserts that it exists."""
    assert isinstance(self._saver, tf.train.Saver)
    return self._saver

  @property
  def summary_writer(self):
    """The `tf.summary.FileWriter`; asserts that it has been created."""
    assert isinstance(self._summary_writer, tf.summary.FileWriter)
    return self._summary_writer

  # endregion : Accessors

  # region : Paths

  @property
  def root_path(self):
    """`hub.record_dir` when `hub.job_dir` is './', otherwise `hub.job_dir`."""
    if hub.job_dir == './': return hub.record_dir
    else: return hub.job_dir

  @property
  def note_dir(self):
    """Folder for exported notes of this model mark."""
    return check_path(self.root_path, hub.note_folder_name,
                      self._model.mark, create_path=hub.export_note)

  @property
  def log_dir(self):
    """Folder for summary logs of this model mark."""
    return check_path(self.root_path, hub.log_folder_name,
                      self._model.mark, create_path=hub.summary)

  @property
  def ckpt_dir(self):
    """Checkpoint folder; `hub.specified_ckpt_path` wins when it is set."""
    if hub.specified_ckpt_path is not None: return hub.specified_ckpt_path
    return check_path(self.root_path, hub.ckpt_folder_name,
                      self._model.mark, create_path=hub.save_model)

  @property
  def snapshot_dir(self):
    """Folder for snapshots of this model mark."""
    return check_path(self.root_path, hub.snapshot_folder_name,
                      self._model.mark, create_path=hub.snapshot)

  @property
  def model_path(self):
    """This property will be used only when checkpoint is to be saved.
        Old name format: XXXX.model
        New name example: recurrent.predictor(26.799_epochs)-train-1800
        Where XXXX denotes self._model.model_name
    """
    name = '{}.{}'.format(self._model.affix, self._model.model_name.lower())
    return os.path.join(self.ckpt_dir, name)

  @property
  def gather_path(self):
    """Path of the text file into which notes are gathered."""
    return os.path.join(check_path(self.root_path), hub.gather_file_name)

  @property
  def gather_summ_path(self):
    """Path of the pickled summary file into which notes are gathered."""
    return os.path.join(check_path(self.root_path), hub.gather_summ_name)

  # endregion : Paths

  # endregion : Properties

  # region : Public Methods

  def get_status_feed_dict(self, is_training):
    """Return a feed dict mapping the is-training placeholder to the flag."""
    assert isinstance(is_training, bool)
    feed_dict = {self._is_training: is_training}
    return feed_dict

  def load(self):
    """Try to load a checkpoint from `ckpt_dir`.

    Returns:
      Whatever `load_checkpoint` returns, or `(False, 0, None)` when
      loading is skipped. `launch_model` unpacks the result as
      (load_flag, counter, rounds).
    """
    # TODO: when save_model option is turned off and the user want to
    #   try loading the exist model, set overwrite to False
    if not hub.save_model and hub.overwrite: return False, 0, None
    return load_checkpoint(self.ckpt_dir, self.session, self._saver)

  def save_model(self, rounds=None, suffix=None):
    """Save a checkpoint under `model_path`.

    Args:
      rounds: used only by trainer; appended as '(<rounds>_rounds)'.
      suffix: optional text appended after a dash.
    """
    path = self.model_path
    if rounds is not None: path += '({:.3f}_rounds)'.format(rounds)
    if suffix is not None: path += '-{}'.format(suffix)
    save_checkpoint(path, self.session, self._saver, self._model.counter)

  def reset_saver(self):
    """This method will be used in some very special cases, e.g. for
       saving train_stats used in dynamic evaluation (krause, 2018)
    """
    self._saver = tf.train.Saver(
      var_list=self._model.variable_to_save, max_to_keep=2)

  def launch_model(self, overwrite=False):
    """Launch a session on the agent's graph and prepare the model.

    Cleans the output folders when overwriting is requested, launches the
    session, creates the saver and summary writer, initializes variables
    and then tries to restore an existing checkpoint.

    Args:
      overwrite: clean existing output folders first; only takes effect
        when `hub.overwrite` is also set.

    Returns:
      The load flag reported by `load()`.

    Raises:
      TypeError: if `hub.visible_gpu_id` is neither an int nor a str.
      AssertionError: if pruning iterations are requested but no
        checkpoint could be loaded.
    """
    if hub.suppress_logging: console.suppress_logging()
    # Before launch session, do some cleaning work
    if overwrite and hub.overwrite:
      paths = []
      if hub.summary: paths.append(self.log_dir)
      if hub.save_model: paths.append(self.ckpt_dir)
      if hub.snapshot: paths.append(self.snapshot_dir)
      if hub.export_note: paths.append(self.note_dir)
      # NOTE(review): the collected paths were never used before; emptying
      #   them is the presumed intent of `overwrite` -- confirm
      self._clear_paths(paths)
    if hub.summary: self._check_bash()

    # Launch session on self.graph
    console.show_status('Launching session ...')
    config = tf.ConfigProto()
    if hub.visible_gpu_id is not None:
      gpu_id = hub.visible_gpu_id
      if isinstance(gpu_id, int): gpu_id = '{}'.format(gpu_id)
      elif not isinstance(gpu_id, str): raise TypeError(
        '!! Visible GPU id provided must be an integer or a string')
      os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    if not hub.allow_growth:
      value = hub.gpu_memory_fraction
      config.gpu_options.per_process_gpu_memory_fraction = value
    self._session = tf.Session(graph=self._graph, config=config)
    console.show_status('Session launched')
    # Prepare some tools
    # The saver must exist before `load()` hands it to `load_checkpoint`
    with self._graph.as_default():
      self.reset_saver()
    if hub.summary or hub.hp_tuning:
      self._summary_writer = tf.summary.FileWriter(self.log_dir)

    # Initialize all variables
    with self._graph.as_default():
      self._session.run(tf.global_variables_initializer())
    # Set init_val for pruner if necessary
    # .. if existed model is loaded, variables will be overwritten
    if hub.prune_on: context.pruner.set_init_val_lottery18()

    # Try to load exist model
    load_flag, self._model.counter, self._model.rounds = self.load()
    # Sanity check
    if hub.prune_on and hub.pruning_iterations > 0:
      if not load_flag: raise AssertionError(
        '!! Model {} should be initialized'.format(self._model.mark))

    if not load_flag:
      assert self._model.counter == 0
      # Add graph
      if hub.summary: self._summary_writer.add_graph(self._session.graph)
      # Write model description to file
      if hub.snapshot:
        description_path = os.path.join(self.snapshot_dir, 'description.txt')
        write_file(description_path, self._model.description)
      # Show status
      console.show_status('New model initiated')
    elif hub.branch_suffix not in [None, '']:
      hub.mark += hub.branch_suffix
      self._model.mark = hub.mark
      console.show_status('Checkpoint switched to branch `{}`'.format(hub.mark))

    self._model.launched = True
    self.take_notes('Model launched')

    # Handle structure detail here

    return load_flag

  def shutdown(self):
    """Close the summary writer (when one is in use) and the session."""
    if hub.summary or hub.hp_tuning:
      if self._summary_writer is not None: self._summary_writer.close()
    # Tolerate a shutdown that happens before any session was launched
    if self._session is not None: self._session.close()

  def write_summary(self, summary, step=None):
    """Write `summary` to the summary writer; no-op unless `hub.summary`.

    Args:
      summary: the summary passed on to `FileWriter.add_summary`.
      step: global step; when omitted it is derived from the trainer's
        total rounds (if `hub.epoch_as_step`) or from the model counter.
    """
    if not hub.summary: return
    if step is None:
      assert context.trainer is not None
      if hub.epoch_as_step and context.trainer.total_rounds is not None:
        step = int(context.trainer.total_rounds * 1000)
      else: step = self._model.counter
    assert isinstance(self._summary_writer, tf.summary.FileWriter)
    self._summary_writer.add_summary(summary, step)

  def save_plot(self, fig, filename):
    """Save `fig` as `filename` inside the snapshot folder."""
    imtool.save_plt(fig, '{}/{}'.format(self.snapshot_dir, filename))

  # endregion : Public Methods

  # region : Public Methods for Note

  # region : For TensorViewer

  def take_down_scalars_and_tensors(self, scalars, tensors):
    """Record the `scalars` and `tensors` dicts in the note at this step."""
    assert isinstance(scalars, dict) and isinstance(tensors, dict)
    if hub.epoch_as_step and context.trainer.total_rounds is not None:
      step = int(context.trainer.total_rounds * 1000)
    else: step = self._model.counter
    self._note.take_down_scalars_and_tensors(step, scalars, tensors)

  # endregion : For TensorViewer

  # region : For SummaryViewer

  def put_down_configs(self, th):
    """Validate that `th` is a `Config`.

    NOTE(review): unlike `put_down_criterion`, nothing is forwarded to the
      note here; a delegation to the note looks missing -- confirm against
      the Note API before adding one.
    """
    assert isinstance(th, Config)

  def put_down_criterion(self, name, value):
    """Forward the (name, value) criterion to the note."""
    self._note.put_down_criterion(name, value)

  def gather_to_summary(self):
    """Append the current note to the pickled list at `gather_summ_path`."""
    import pickle
    # Try to load note list into summaries
    file_path = self.gather_summ_path
    if os.path.exists(file_path):
      # NOTE: pickle.load executes arbitrary code from the file; only
      #   gather from summary files this project wrote itself
      with open(file_path, 'rb') as f: summary = pickle.load(f)
      assert len(summary) > 0
    else: summary = []
    # Add note to list and save
    note = self._note.tensor_free if hub.gather_only_scalars else self._note
    summary.append(note)
    with open(file_path, 'wb') as f:
      pickle.dump(summary, f, pickle.HIGHEST_PROTOCOL)
    # Show status
    console.show_status('Note added to summaries ({} => {}) at `{}`'.format(
      len(summary) - 1, len(summary), file_path))

  # endregion : For SummaryViewer

  def take_notes(self, content, date_time=True, prompt=None):
    """Add one line to the note.

    Args:
      content: the text to record.
      date_time: prefix the line with a timestamp; ignored when `prompt`
        is a string.
      prompt: optional prefix that replaces the timestamp.

    Raises:
      TypeError: if `content` is not a string.
    """
    if not isinstance(content, str):
      raise TypeError('!! content must be a string')
    if isinstance(prompt, str):
      date_time = False
      content = '{} {}'.format(prompt, content)
    if date_time:
      content = '{} {}'.format(self._timestamp(), content)
    # The formatted line used to be discarded here
    # NOTE(review): assumes Note exposes `write_line` -- confirm
    self._note.write_line(content)

  def show_notes(self):
    """Print the note content.

    NOTE(review): the body was missing; plain `print` is used because the
      console helpers for this are not visible here -- confirm.
    """
    print(self._note.content)

  def export_notes(self, filename='notes'):
    """Export the note as `<filename>.txt` and `<filename>.note`.

    The text file is appended to; the .note file is overwritten.
    """
    assert hub.export_note
    import pickle
    note_dir = self.note_dir
    # Export .txt file
    file_path = '{}/{}.txt'.format(note_dir, filename)
    with open(file_path, 'a') as writer:
      writer.write('=' * 79 + '\n')
      writer.write(self._note.content + '\n')
    # Export .note file
    file_path = '{}/{}.note'.format(note_dir, filename)
    # NOTE(review): the .note file was reported but never written; the note
    #   is pickled the same way `gather_to_summary` does -- confirm that
    #   this is the format expected by readers of .note files
    with open(file_path, 'wb') as f:
      pickle.dump(self._note, f, pickle.HIGHEST_PROTOCOL)
    console.show_status('Note exported to `{}`'.format(file_path))

  def gather_notes(self, take_down_time=False):
    """Append the note content to the gather file.

    Args:
      take_down_time: prefix the entry with a timestamp.
    """
    assert hub.gather_note
    # Gather notes to .txt file
    line = self._note.content
    if take_down_time:
      line = '[{}] {}'.format(self._timestamp(), line)
    # Append mode creates the file when it does not exist
    # NOTE(review): the old code read the existing lines without using
    #   them, which hints that prepending may have been intended -- confirm
    with open(self.gather_path, 'a') as f:
      f.write(line + '\n')
      f.write('-' * 79 + '\n')
      # TODO: find a way to update immediately after training is over
    # Gather notes to .summ file
    if hub.gather_summ_name is not None: self.gather_to_summary()

  # endregion : Public Methods for Note

  # region : Private Methods

  def _init_graph(self, graph):
    """Adopt `graph` (or create one) and add the is-training placeholder."""
    if graph is not None:
      assert isinstance(graph, tf.Graph)
      self._graph = graph
    else: self._graph = tf.Graph()
    # Initialize graph variables
    with self._graph.as_default():
      self._is_training = tf.placeholder(
        dtype=tf.bool, name=pedia.is_training)
      tf.add_to_collection(pedia.is_training, self._is_training)

    # TODO
    # When linking batch-norm layer (and dropout layer),
    #   this placeholder will be got from default graph
    # self._graph.is_training = self._is_training
    # assert context.current_graph is not None
    if not hub.suppress_current_graph: context.current_graph = self._graph
    # tfr.current_graph = self._graph

  def _check_bash(self):
    """Write the tensorboard launch script(s) into the root path if absent."""
    command = 'tensorboard --logdir=./logs/ --port={}'.format(hub.tb_port)
    file_path = check_path(self.root_path, create_path=True)
    # NOTE(review): the second entry is empty; a unix script name may have
    #   been lost -- confirm
    file_names = ['win_launch_tensorboard.bat', '']
    for file_name in file_names:
      # An empty name would join to the directory itself
      if not file_name: continue
      path = os.path.join(file_path, file_name)
      if not os.path.exists(path):
        with open(path, 'w') as f: f.write(command)

  @staticmethod
  def _clear_paths(paths):
    """Delete everything inside each folder in `paths`, keeping the roots."""
    for path in paths:
      # Walk bottom-up so that folders are already empty when removed
      for root, dirs, files in os.walk(path, topdown=False):
        for name in files: os.remove(os.path.join(root, name))
        for name in dirs:
          target = os.path.join(root, name)
          # A symlink to a folder must be unlinked, not rmdir'ed
          if os.path.islink(target): os.remove(target)
          else: os.rmdir(target)

  @staticmethod
  def _timestamp():
    """Return the current local time formatted as '[yy-Mon-dd HH:MM:SS]'."""
    return time.strftime('[{}-{}-%d %H:%M:%S]'.format(
      time.strftime('%Y')[2:], time.strftime('%B')[:3]))

  # endregion : Private Methods
