Example 1 (score: 0)
    def launch_model(self, overwrite=False):
        """Launch a tf.Session on self._graph and prepare training tools.

        Optionally clears previous output directories, configures GPU
        memory usage, creates the saver and summary writer, initializes
        all variables and finally tries to restore an existing checkpoint.

        :param overwrite: if True (and hub.overwrite is also set), the
                          log/checkpoint/snapshot/note directories are
                          cleared before the session is launched.
        :return: True if an existing model was loaded, False otherwise.
        """
        if hub.suppress_logging: console.suppress_logging()
        # Before launching the session, do some cleaning work
        if overwrite and hub.overwrite:
            paths = []
            if hub.summary: paths.append(self.log_dir)
            if hub.save_model: paths.append(self.ckpt_dir)
            if hub.snapshot: paths.append(self.snapshot_dir)
            if hub.export_note: paths.append(self.note_dir)
            clear_paths(paths)
        if hub.summary: self._check_bash()

        # Launch session on self.graph
        console.show_status('Launching session ...')
        config = tf.ConfigProto()
        if hub.allow_growth:
            # BUG FIX: hub.allow_growth was previously never forwarded to the
            # session config (TF's default is allow_growth=False), so setting
            # the flag silently had no effect.
            config.gpu_options.allow_growth = True
        else:
            # Cap the GPU memory this process may claim
            config.gpu_options.per_process_gpu_memory_fraction = (
                hub.gpu_memory_fraction)
        self._session = tf.Session(graph=self._graph, config=config)
        console.show_status('Session launched')
        # Prepare some tools
        self._saver = tf.train.Saver(var_list=self._model.variable_to_save)
        if hub.summary or hub.hp_tuning:
            self._summary_writer = tf.summary.FileWriter(self.log_dir)

        # Initialize all variables (may be overwritten by a loaded checkpoint)
        self._session.run(tf.global_variables_initializer())
        # Try to load an existing model
        load_flag, self._model.counter = self.load()
        if not load_flag:
            assert self._model.counter == 0
            # Add graph to the summary so it shows up in TensorBoard
            if hub.summary: self._summary_writer.add_graph(self._session.graph)
            # Write model description to file
            if hub.snapshot:
                description_path = os.path.join(self.snapshot_dir,
                                                'description.txt')
                write_file(description_path, self._model.description)
            # Show status
            console.show_status('New model initiated')

        self._model.launched = True
        self.take_notes('Model launched')
        return load_flag
Example 2 (score: 0)
    def launch_model(self, overwrite=False):
        """Create the tf.Session for this model and try to restore it.

        Reconciles interdependent runtime FLAGS, optionally wipes output
        directories from previous runs, launches a session on self._graph
        and either restores an existing checkpoint or initializes all
        variables from scratch.

        :param overwrite: when True (and training locally), clear previous
                          log/checkpoint/snapshot directories first.
        :return: True if a checkpoint was successfully restored.
        """
        # Reconcile interdependent flags before anything else happens
        FLAGS.cloud = FLAGS.cloud or "://" in self.job_dir
        FLAGS.progress_bar = FLAGS.progress_bar and not FLAGS.cloud
        FLAGS.summary = FLAGS.summary and not FLAGS.hpt
        FLAGS.save_model = FLAGS.save_model and not FLAGS.hpt
        FLAGS.snapshot = FLAGS.snapshot and not FLAGS.cloud

        # Clean up outputs of previous runs when overwriting a local training
        if overwrite and FLAGS.train and not FLAGS.cloud:
            stale_dirs = []
            if FLAGS.summary:
                stale_dirs.append(self.log_dir)
            if FLAGS.save_model:
                stale_dirs.append(self.ckpt_dir)
            if FLAGS.snapshot:
                stale_dirs.append(self.snapshot_dir)
            clear_paths(stale_dirs)

        console.show_status('Launching session ...')
        self._session = tf.Session(graph=self._graph)
        console.show_status('Session launched')
        self._saver = tf.train.Saver()
        if FLAGS.summary or FLAGS.hpt:
            self._summary_writer = tf.summary.FileWriter(self.log_dir)

        # Restore an existing checkpoint only when model saving is enabled
        load_flag, self._counter = False, 0
        if FLAGS.save_model:
            load_flag, self._counter = self._load()
        if not load_flag:
            assert self._counter == 0
            # No checkpoint found: start from freshly initialized variables
            self._session.run(tf.global_variables_initializer())
            # Export the graph so it shows up in TensorBoard
            if FLAGS.summary:
                self._summary_writer.add_graph(self._session.graph)
            # Dump the model description next to the snapshots
            if FLAGS.snapshot:
                description_path = os.path.join(self.snapshot_dir,
                                                'description.txt')
                write_file(description_path, self.description)

        return load_flag
Example 3 (score: 0)
  def launch_model(self, overwrite=False):
    """Launch a tf.Session on self._graph and prepare training tools.

    Optionally clears previous output directories, pins the visible GPU,
    configures GPU memory usage, creates the summary writer, initializes
    all variables and then tries to restore an existing checkpoint —
    possibly switching the checkpoint mark to a branch suffix afterwards.

    :param overwrite: if True (and hub.overwrite is also set), the
                      log/checkpoint/snapshot/note directories are
                      cleared before the session is launched.
    :return: True if an existing model was loaded, False otherwise.
    """
    if hub.suppress_logging: console.suppress_logging()
    # Before launch session, do some cleaning work
    if overwrite and hub.overwrite:
      paths = []
      if hub.summary: paths.append(self.log_dir)
      if hub.save_model: paths.append(self.ckpt_dir)
      if hub.snapshot: paths.append(self.snapshot_dir)
      if hub.export_note: paths.append(self.note_dir)
      clear_paths(paths)
    if hub.summary: self._check_bash()

    # Launch session on self.graph
    console.show_status('Launching session ...')
    config = tf.ConfigProto()
    # Restrict which GPUs TF may see; must happen before the Session is
    # created (TF reads CUDA_VISIBLE_DEVICES at device initialization)
    if hub.visible_gpu_id is not None:
      gpu_id = hub.visible_gpu_id
      if isinstance(gpu_id, int): gpu_id = '{}'.format(gpu_id)
      elif not isinstance(gpu_id, str): raise TypeError(
        '!! Visible GPU id provided must be an integer or a string')
      os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    # NOTE(review): hub.allow_growth is never forwarded to
    # config.gpu_options.allow_growth (TF default is False), so when it is
    # True the session uses TF's default allocation — confirm intended.
    if not hub.allow_growth:
      value = hub.gpu_memory_fraction
      config.gpu_options.per_process_gpu_memory_fraction = value
    self._session = tf.Session(graph=self._graph, config=config)
    console.show_status('Session launched')
    # Prepare some tools
    self.reset_saver()
    if hub.summary or hub.hp_tuning:
      self._summary_writer = tf.summary.FileWriter(self.log_dir)

    # Initialize all variables
    self._session.run(tf.global_variables_initializer())
    # Set init_val for pruner if necessary
    # .. if existed model is loaded, variables will be overwritten
    if hub.prune_on: context.pruner.set_init_val_lottery18()

    # Try to load exist model
    load_flag, self._model.counter, self._model.rounds = self.load()
    # Sanity check: iterative pruning must resume from a saved model,
    # otherwise the lottery-ticket procedure would restart from scratch
    if hub.prune_on and hub.pruning_iterations > 0:
      if not load_flag: raise AssertionError(
        '!! Model {} should be initialized'.format(self._model.mark))

    if not load_flag:
      assert self._model.counter == 0
      # Add graph so it shows up in TensorBoard
      if hub.summary: self._summary_writer.add_graph(self._session.graph)
      # Write model description to file
      if hub.snapshot:
        description_path = os.path.join(self.snapshot_dir, 'description.txt')
        write_file(description_path, self._model.description)
      # Show status
      console.show_status('New model initiated')
    elif hub.branch_suffix not in [None, '']:
      # A checkpoint was loaded and a branch is requested: append the
      # suffix so subsequent saves go to the branch checkpoint instead
      hub.mark += hub.branch_suffix
      self._model.mark = hub.mark
      console.show_status('Checkpoint switched to branch `{}`'.format(hub.mark))

    self._model.launched = True
    self.take_notes('Model launched')

    # Handle structure detail here
    self._model.handle_structure_detail()

    return load_flag