Example #1
 def _end_training(self, rounds):
     if self.th.progress_bar: console.clear_line()
     # If this is an hp-tuning task, write record summary
     if self.th.hp_tuning:
         assert not self.th.summary
         self.key_metric.write_record_summary()
     # Flush summary
     if self.th.summary or self.th.hp_tuning:
         self.model.agent.summary_writer.flush()
     # Take notes
     if self.is_online:
         self.model.agent.take_notes(
             'End training after {} iterations'.format(self.counter))
     else:
         total_round = ('' if self.total_rounds is None else
                        ' ({:.1f} total)'.format(self.total_rounds))
         self.model.agent.take_notes(
             'End training after {} rounds{}'.format(rounds, total_round))
     # Evaluate
     if self._evaluate is not None:
         # Load the best model if necessary
         if self.th.save_model:
             flag, _, _ = self.model.agent.load()
             assert flag
         # Evaluate model
         self._evaluate(self)
     # Show RAS if necessary
     if self.th.lives > 0:
         ras_info = self.metrics_manager.RAR_string
         console.show_status(ras_info)
         self.model.agent.take_notes(ras_info)
Example #2
 def make_strings(cls,
                  num,
                  unique=True,
                  exclusive=None,
                  embedded=False,
                  multiple=1,
                  verbose=False,
                  interleave=True):
     # Check input
     if exclusive is None: exclusive = []
     elif not isinstance(exclusive, list):
         raise TypeError('!! exclusive must be a list of Reber strings')
     # Make strings
     reber_list = []
     long_token = None
     for i in range(num):
         if interleave:
             long_token = 'T' if long_token in ('P', None) else 'P'
         while True:
             string = ReberGrammar(embedded,
                                   multiple=multiple,
                                   specification=long_token)
             if unique and string in reber_list: continue
             if string in exclusive: continue
             reber_list.append(string)
             break
         if verbose:
             console.clear_line()
             console.print_progress(i + 1, num)
     if verbose: console.clear_line()
     # Return a list of Reber strings
     return reber_list
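
A hedged usage sketch for the classmethod above. It assumes `make_strings` is exposed on the `ReberGrammar` class that the body instantiates (the defining class is not shown here), and it only uses keyword arguments from the signature:

    # Generate a pool of embedded Reber strings with a progress bar.
    train_strings = ReberGrammar.make_strings(
        200, unique=True, embedded=True, multiple=1, verbose=True)
    # Passing the training pool as `exclusive` keeps the test pool disjoint from it.
    test_strings = ReberGrammar.make_strings(
        50, unique=True, exclusive=train_strings, embedded=True, verbose=True)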
Example #3
 def _print_progress(self, epi, start_time, steps, **kwargs):
     """Use a awkward way to avoid IDE warning :("""
     console.clear_line()
     console.show_status(
         'Episode {} [{} total] {} steps, Time elapsed = {:.2f} sec'.format(
             epi, self.counter, steps,
             time.time() - start_time))
     console.print_progress(epi, kwargs.get('total'))
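
The pattern shared by most snippets in this collection is: clear the current console line, print a status message, then redraw the progress bar. A minimal sketch of that cycle, assuming `console` is already imported from the project's own console utility module (as in the snippets above); `work` and `total` are placeholders:

    import time

    def work(i):
        time.sleep(0.01)  # placeholder for the real per-step work

    total = 100
    for i in range(total):
        work(i)
        console.clear_line()                   # wipe the previous progress bar
        console.show_status('Finished step {}'.format(i + 1))
        console.print_progress(i + 1, total)   # redraw the bar below the status
    console.clear_line()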
Example #4
    def _inter_cut(self, content, start_time=None):
        # If run on the cloud, do not show progress bar
        if not FLAGS.progress_bar:
            console.show_status(content)
            return

        console.clear_line()
        console.show_status(content)
        console.print_progress(progress=self._training_set.progress,
                               start_time=start_time)
Example #5
    def _snapshot(self, progress):
        if self._snapshot_function is None:
            return

        filename = 'train_{}_episode'.format(self.counter)
        fullname = "{}/{}".format(self.snapshot_dir, filename)
        self._snapshot_function(fullname)

        console.clear_line()
        console.write_line("[Snapshot] snapshot saved to {}".format(filename))
        console.print_progress(progress=progress)
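
The method above only invokes whatever callable was registered as `self._snapshot_function`, passing it a path prefix; the `train` methods later in this collection accept that callable through their `snapshot_function` argument. A purely illustrative sketch of such a callable (the body and the file extension are assumptions; only the single path argument is implied by the snippet):

    def my_snapshot(fullname):
        # `fullname` arrives without an extension, e.g. '<snapshot_dir>/train_42_episode'.
        with open(fullname + '.txt', 'w') as f:
            f.write('snapshot placeholder')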
Example #6
  def _print_progress(self, epc, start_time, info_dict, **kwargs):
    # Generate loss string
    loss_strings = ['{} = {:.3f}'.format(k, v)
                    for k, v in info_dict.items()]
    loss_string = ', '.join(loss_strings)

    total_epoch = self._counter / self._training_set.batches_per_epoch
    if FLAGS.progress_bar: console.clear_line()
    console.show_status(
      'Epoch {} [{:.1f} Total] {}'.format(epc + 1, total_epoch, loss_string))
    if FLAGS.progress_bar:
      console.print_progress(progress=self._training_set.progress,
                             start_time=start_time)
Example #7
 def _synthesize(cls, size, L, N, fixed_length, verbose=False):
     features, targets = [], []
     for i in range(size):
         x, y = engine(L, N, fixed_length)
         features.append(x)
         targets.append(y)
         if verbose:
             console.clear_line()
             console.print_progress(i + 1, size)
     # Wrap data into a SequenceSet
     data_set = SequenceSet(features,
                            summ_dict={'targets': targets},
                            n_to_one=True,
                            name='TemporalOrder')
     return data_set
Example #8
 def _end_training(self, rounds):
     if self.th.progress_bar: console.clear_line()
     # If this is an hp-tuning task, write record summary
     if self.th.hp_tuning:
         assert not self.th.summary
         self.metric.write_record_summary()
     # Flush summary
     if self.th.summary or self.th.hp_tuning:
         self.model.agent.summary_writer.flush()
     # Take notes
     total_round = ('' if self.total_rounds is None else
                    ' ({:.1f} total)'.format(self.total_rounds))
     self.model.agent.take_notes('End training after {} rounds{}'.format(
         rounds, total_round))
     # Add metric info into notes
     if self.th.validation_on: self.model.take_down_metric()
Example #9
    def _training_match(self, agent, rounds, progress, rate_thresh):
        # TODO: inference graph is hard to build under this frame => compromise
        if self._opponent is None:
            return
        assert isinstance(agent, FMDPAgent)

        console.clear_line()
        title = 'Training match with {}'.format(self._opponent.player_name)
        rate = self.compete(agent, rounds, self._opponent, title=title)
        if rate >= rate_thresh and isinstance(self._opponent, TDPlayer):
            # Find a stronger opponent
            self._opponent._load()
            self._opponent.player_name = 'Shadow_{}'.format(
                self._opponent.counter)
            console.show_status('Opponent updated')

        console.print_progress(progress=progress)
Example #10
    def _download(cls, file_path, url=None):
        import os
        import time
        from six.moves import urllib
        # Show status
        file_name = cls._split_path(file_path)[-1]
        console.show_status('Downloading {} ...'.format(file_name))
        start_time = time.time()

        def _progress(count, block_size, total_size):
            console.clear_line()
            console.print_progress(count * block_size, total_size, start_time)

        url = cls.DATA_URL if url is None else url
        file_path, _ = urllib.request.urlretrieve(url, file_path, _progress)
        stat_info = os.stat(file_path)
        console.clear_line()
        console.show_status('Successfully downloaded {} ({} bytes).'.format(
            file_name, stat_info.st_size))
Example #11
 def make_strings(cls, num, unique=True, exclusive=None, embedded=False,
                  verbose=False):
   # Check input
   if exclusive is None: exclusive = []
   elif not isinstance(exclusive, list):
     raise TypeError('!! exclusive must be a list of Reber strings')
   # Make strings
   reber_list = []
   for i in range(num):
     while True:
       string = ReberGrammar(embedded)
       if unique and string in reber_list: continue
       if string in exclusive: continue
       reber_list.append(string)
       break
     if verbose:
       console.clear_line()
       console.print_progress(i + 1, num)
   if verbose: console.clear_line()
   # Return a list of Reber strings
   return reber_list
Example #12
    def train(self,
              agent,
              episodes=500,
              print_cycle=0,
              snapshot_cycle=0,
              match_cycle=0,
              rounds=100,
              rate_thresh=1.0,
              shadow=None,
              save_cycle=100,
              snapshot_function=None):
        # Validate agent
        if not isinstance(agent, FMDPAgent):
            raise TypeError('Agent should be a FMDP-agent')

        # Check settings TODO: code should be reused
        if snapshot_function is not None:
            if not callable(snapshot_function):
                raise ValueError('snapshot_function must be callable')
            self._snapshot_function = snapshot_function

        print_cycle = FLAGS.print_cycle if FLAGS.print_cycle >= 0 else print_cycle
        snapshot_cycle = (FLAGS.snapshot_cycle
                          if FLAGS.snapshot_cycle >= 0 else snapshot_cycle)
        match_cycle = FLAGS.match_cycle if FLAGS.match_cycle >= 0 else match_cycle

        # Show configurations
        console.show_status('Configurations:')
        console.supplement('episodes: {}'.format(episodes))

        # Do some preparation
        if self._session is None:
            self.launch_model()

        assert isinstance(self._graph, tf.Graph)
        with self._graph.as_default():
            if self._merged_summary is None:
                self._merged_summary = tf.summary.merge_all()

        # Set opponent
        if match_cycle > 0:
            if shadow is None:
                self._opponent = FMDRandomPlayer()
                self._opponent.player_name = 'Random Player'
            elif isinstance(shadow, TDPlayer):
                self._opponent = shadow
                self._opponent.player_name = 'Shadow_{}'.format(
                    self._opponent.counter)
            else:
                raise TypeError('Opponent should be an instance of TDPlayer')

        # Begin training iteration
        assert isinstance(agent, FMDPAgent)
        console.section('Begin episodes')
        for epi in range(1, episodes + 1):
            # Initialize variable
            agent.restart()
            if hasattr(agent, 'default_first_move'):
                agent.default_first_move()
            # Record episode start time
            start_time = time.time()
            steps = 0

            state = agent.state
            summary = None
            # Begin current episode
            while not agent.terminated:
                # Make a move
                next_value = self.next_step(agent)
                steps += 1
                # Update model
                state = np.reshape(state, (1, ) + state.shape)
                next_value = np.reshape(np.array(next_value), (1, 1))
                feed_dict = {
                    self.input_[0]: state,
                    self._next_value: next_value
                }
                feed_dict.update(self._get_status_feed_dict(is_training=True))
                assert isinstance(self._session, tf.Session)
                summary, _ = self._session.run(
                    [self._merged_summary, self._update_op], feed_dict)

                state = agent.state

            # End of current episode
            self.counter += 1

            assert isinstance(self._summary_writer, tf.summary.FileWriter)
            self._summary_writer.add_summary(summary, self.counter)

            if print_cycle > 0 and np.mod(self.counter, print_cycle) == 0:
                self._print_progress(epi, start_time, steps, total=episodes)
            if snapshot_cycle > 0 and np.mod(self.counter,
                                             snapshot_cycle) == 0:
                self._snapshot(epi / episodes)
            if match_cycle > 0 and np.mod(self.counter, match_cycle) == 0:
                self._training_match(agent, rounds, epi / episodes,
                                     rate_thresh)
            if np.mod(self.counter, save_cycle) == 0:
                self._save(self.counter)

        # End training
        console.clear_line()
        self._summary_writer.flush()
        self.shutdown()
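
A hedged call sketch for the `train` method above, using only keyword parameters from its signature; `td_player` and `game_agent` are placeholder names, and the agent must be an FMDPAgent per the check at the top of the method:

    td_player.train(game_agent,
                    episodes=1000,
                    print_cycle=10,
                    snapshot_cycle=50,
                    match_cycle=200,
                    rounds=100,
                    rate_thresh=0.9,
                    save_cycle=100)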
Example #13
 def _progress(count, block_size, total_size):
     console.clear_line()
     console.print_progress(count * block_size, total_size, start_time)
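
This is the nested reporthook from Example #10 shown in isolation; `start_time` is captured from the enclosing scope. A minimal sketch of how such a hook is wired into `urllib.request.urlretrieve`, mirroring Example #10 (the URL and target path are placeholders):

    import time
    from six.moves import urllib

    start_time = time.time()

    def _progress(count, block_size, total_size):
        # urlretrieve calls this after every block with the running byte counts.
        console.clear_line()
        console.print_progress(count * block_size, total_size, start_time)

    # Placeholder URL and target path, for illustration only.
    urllib.request.urlretrieve('http://example.com/data.bin', 'data.bin', _progress)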
Example #14
    def train(self,
              epoch=1,
              batch_size=128,
              training_set=None,
              validation_set=None,
              print_cycle=0,
              snapshot_cycle=0,
              snapshot_function=None,
              probe=None,
              **kwargs):
        # Check data
        if training_set is not None:
            self._training_set = training_set
        if validation_set is not None:
            self._validation_set = validation_set
        if self._training_set is None:
            raise ValueError('!! Data for training not found')
        elif not isinstance(self._training_set, TFData):
            raise TypeError(
                '!! Data for training must be an instance of TFData')
        if probe is not None and not callable(probe):
            raise TypeError('!! Probe must be callable')

        if snapshot_function is not None:
            if not callable(snapshot_function):
                raise ValueError('!! snapshot_function must be callable')
            self._snapshot_function = snapshot_function

        self._init_smart_train(validation_set)

        epoch_tol = FLAGS.epoch_tol

        # Get epoch and batch size
        epoch = FLAGS.epoch if FLAGS.epoch > 0 else epoch
        batch_size = FLAGS.batch_size if FLAGS.batch_size > 0 else batch_size
        assert isinstance(self._training_set, TFData)
        self._training_set.set_batch_size(batch_size)

        # Get print and snapshot cycles
        print_cycle = FLAGS.print_cycle if FLAGS.print_cycle >= 0 else print_cycle
        snapshot_cycle = (FLAGS.snapshot_cycle
                          if FLAGS.snapshot_cycle >= 0 else snapshot_cycle)

        # Run pre-train method
        self._pretrain(**kwargs)

        # Show configurations
        console.show_status('Configurations:')
        console.supplement('Training set feature shape: {}'.format(
            self._training_set.features.shape))
        console.supplement('epochs: {}'.format(epoch))
        console.supplement('batch size: {}'.format(batch_size))

        # Do some preparation
        if self._session is None:
            self.launch_model()
        if self._merged_summary is None:
            self._merged_summary = tf.summary.merge_all()

        # Begin iteration
        with self._session.as_default():
            for epc in range(epoch):
                console.section('Epoch {}'.format(epc + 1))
                # Add a new list to metric log if smart_train is on
                if self._train_status['metric_on']: self._metric_log.append([])
                # Record epoch start time
                start_time = time.time()
                while True:
                    # Get data batch
                    data_batch, end_epoch_flag = self._training_set.next_batch(
                        shuffle=FLAGS.shuffle and FLAGS.train)
                    # Increase counter, counter may be used in _update_model
                    self._counter += 1
                    # Update model
                    loss_dict = self._update_model(data_batch, **kwargs)
                    # Print status
                    if print_cycle > 0 and np.mod(self._counter - 1,
                                                  print_cycle) == 0:
                        loss_dict, new_record = self._update_loss_dict(
                            loss_dict, probe)
                        self._print_progress(epc,
                                             start_time,
                                             loss_dict,
                                             data_batch=data_batch)
                        if new_record:
                            self._last_epoch = epc
                            if FLAGS.save_best and epc + 1 >= FLAGS.dont_save_until:
                                if FLAGS.save_model:
                                    self._save(self._counter)
                                    self._inter_cut('[New Record] Model saved')

                    # Snapshot
                    if (FLAGS.snapshot and snapshot_cycle > 0 and np.mod(
                            self._counter - 1, snapshot_cycle) == 0):
                        self._snapshot()
                    # Check flag
                    if end_epoch_flag:
                        if FLAGS.progress_bar: console.clear_line()
                        console.show_status('End of epoch. Elapsed time is '
                                            '{:.1f} secs'.format(time.time() -
                                                                 start_time))
                        break

                # End of epoch
                break_flag = False
                since_last = epc - self._last_epoch
                if since_last == 0: self._train_status['bad_apples'] = 0
                else: self._train_status['bad_apples'] += 1
                save_flag = (self._apply_smart_train()
                             if FLAGS.smart_train else True)
                if self._train_status['metric_on']:
                    best_metric = self._session.run(self._best_metric)
                    console.supplement(
                        '[Best {:.3f}] {} epochs since last record appears.'.
                        format(best_metric, since_last))

                if not FLAGS.save_best and save_flag and FLAGS.save_model:
                    self._save(self._counter)
                    console.show_status('Model saved')
                elif since_last >= epoch_tol:
                    break_flag = True

                # Early stop if break flag is true
                if break_flag: break

        # End training
        if FLAGS.progress_bar: console.clear_line()

        # Write HP-tuning metric
        if FLAGS.hpt:
            summary = self._session.run(self._best_metric_sum)
            self._summary_writer.add_summary(summary, self._counter)

        if FLAGS.summary or FLAGS.hpt: self._summary_writer.flush()
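
A hedged usage sketch for the `train` method above, with parameter names taken from its signature; `model`, `train_data`, and `val_data` are placeholders, and both data sets must be TFData instances per the checks at the top:

    model.train(epoch=10,
                batch_size=64,
                training_set=train_data,
                validation_set=val_data,
                print_cycle=20,
                snapshot_cycle=200)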