Example #1
File: trainer.py  Project: zkmartin/tsframe
    def _outer_loop(self):
        hub = self.th
        rnd = 0
        for _ in range(hub.total_outer_loops):
            rnd += 1
            if hub.progress_bar:
                console.section('{} {}'.format(hub.round_name, rnd))
            hub.tic()

            # Begin round (RNN states will be reset here) # TODO
            # self.model.begin_round(th=self.th)

            # Do inner loop
            self._inner_loop(rnd)
            # End of round
            if hub.progress_bar:
                console.show_status(
                    'End of {}. Elapsed time is {:.1f} secs'.format(
                        hub.round_name, hub.toc()))
            # Maybe give a report on metric
            if hub.validation_on:
                self.model.end_round(rnd)
                if self.metric.get_idle_rounds(rnd) > self.th.patience:
                    self.th.raise_stop_flag()
            # Advanced strategy
            self._advanced_strategy(rnd)
            # Export monitor info
            if tfr.monitor.activated: tfr.monitor.export()
            # Maybe save model
            if self._save_model_at_round_end: self._save_model()
            # Early stop
            if hub.stop and self.model.bust(rnd): break

        return rnd
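
The loop above raises a stop flag once the key metric has gone more than patience rounds without improving. A minimal standalone sketch of that patience check, with all names (should_stop, history, values) illustrative rather than tframe APIs:

def should_stop(history, patience=5):
    # history is a list of per-round metric values (higher is better here).
    # Stop once the best value is more than patience rounds old.
    if not history:
        return False
    best_round = max(range(len(history)), key=lambda i: history[i])
    idle_rounds = len(history) - 1 - best_round
    return idle_rounds > patience

# Usage: the metric peaks at round 3 and then stalls for 7 rounds
values = [0.80, 0.85, 0.86, 0.86, 0.85, 0.84, 0.85, 0.84, 0.83, 0.84]
print(should_stop(values, patience=5))  # True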
Example #2
File: wiener.py  Project: zkmartin/nls
    def identify(self, train_set, val_set=None):
        # Sanity check
        if not isinstance(train_set, DataSet):
            raise TypeError('!! train_set must be a DataSet')
        if val_set is not None and not isinstance(val_set, DataSet):
            raise TypeError('!! val_set must be a DataSet')
        if train_set.intensity is None:
            raise TypeError('!! Intensity must not be None')

        # Begin iteration
        console.section('[Wiener] Begin Identification')

        truth_norm, val_input, val_output = None, None, None
        if val_set is not None:
            val_input = val_set.signls[0]
            val_output = val_set.responses[0]
            truth_norm = val_output.norm

        for i in range(len(train_set.signls)):
            input_ = train_set.signls[i]
            output = train_set.responses[i]
            self.cross_correlation(input_, output, train_set.intensity)

            status = 'Round {} finished. '.format(i + 1)

            if val_set is not None:
                pred = self(val_input)
                delta = pred - val_output
                err = delta.norm / truth_norm * 100
                status += 'Error Ratio = {:.3f} %'.format(err)

            console.show_status(status)
            time.sleep(0.2)

        console.show_status('Identification done.')
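
The Error Ratio reported above is the norm of the prediction residual divided by the norm of the ground-truth response, in percent. A minimal NumPy sketch of the same quantity, assuming plain arrays in place of the signal objects used by the project:

import numpy as np

def error_ratio(pred, truth):
    # Relative error in percent: ||pred - truth|| / ||truth|| * 100
    return np.linalg.norm(pred - truth) / np.linalg.norm(truth) * 100

truth = np.sin(np.linspace(0, 2 * np.pi, 100))
pred = truth + 0.05 * np.random.randn(truth.size)
print('Error Ratio = {:.3f} %'.format(error_ratio(pred, truth)))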
Example #3
    def run(self, strategy='grid', rehearsal=False, **kwargs):
        """Run script using the given 'strategy'. This method is compatible with
       old version of tframe script_helper, and should be deprecated in the
       future. """
        # Show section
        console.section('Script Information')
        # Show pot configs
        self._show_dict('Pot Configurations', self.configs)
        # Hyper-parameter info will be shown when the scroll is set
        self.configure(**kwargs)
        # Do some auto configuring, e.g., set greater_is_better based on the given
        #   criterion
        self._auto_config()
        self.pot.set_scroll(self.configs.get('strategy', strategy),
                            **self.configs)
        # Show common parameters
        self._show_dict('Common Settings', self.common_parameters)

        # Begin iteration
        for i, hyper_params in enumerate(self.pot.scroll.combinations()):
            # Show hyper-parameters
            console.show_info('Hyper-parameters:')
            for k, v in hyper_params.items():
                console.supplement('{}: {}'.format(k, v), level=2)
            # Run process if not rehearsal
            if rehearsal: continue
            console.split()
            # Export log if necessary
            if self.pot.logging_is_needed: self._export_log()
            # Run
            self._run_process(hyper_params, i)
Example #4
    def _show_parameters(self):
        console.section('Parameters')

        def _show_config(name, od):
            assert isinstance(od, OrderedDict)
            if len(od) == 0: return
            console.show_info(name)
            for k, v in od.items():
                console.supplement('{}: {}'.format(k, v), level=2)

        _show_config('Common Settings', self.common_parameters)
        _show_config('Hyper Parameters', self.hyper_parameters)
        _show_config('Constraints', self.constraints)
        print()
Example #5
    def _outer_loop(self):
        hub = self.th
        rnd = 0
        for _ in range(hub.total_outer_loops):
            rnd += 1
            if self.is_online: console.section('Iterations Begin')
            else: console.section('{} {}'.format(hub.round_name, rnd))
            hub.tic()

            # Do inner loop
            self._inner_loop(rnd)
            # End of round
            if hub.progress_bar:
                console.show_status(
                    'End of {}. Elapsed time is {:.1f} secs'.format(
                        hub.round_name, hub.toc()))
            # Increment rounds for models trained in epochs
            if self.model.rounds is not None:
                self.model.rounds += 1.0
            # Maybe give a report on metric
            if not self.is_online and hub.validation_on:
                self.model.end_round(rnd)
                if self.key_metric.get_idle_rounds(rnd) > self.th.patience:
                    self.th.raise_stop_flag()

            # Maybe save model (model.rounds var has been increased)
            if self._save_model_at_round_end: self._save_model()

            break_flag = False
            # Early stop via stop flag. TODO: this needs to be unified
            if hub.stop and self.model.bust(rnd): break_flag = True
            # Force terminate
            if hub.force_terminate: break_flag = True
            # Resurrect if possible
            if break_flag and self._lives > 0:
                self.resurrect(rnd)
                if not self.metrics_manager.resurrected:
                    self.metrics_manager.resurrected = True
                    self.metrics_manager.rar0 = self.metrics_manager.early_stop_criterion
                hub.force_terminate = False
                break_flag = False
            # Break if needed to
            if break_flag: break

        # Out of loop
        if hub.gather_note:
            if self.is_online:
                self.model.agent.put_down_criterion('Total Iterations',
                                                    self.counter)
            else:
                self.model.agent.put_down_criterion('Total Rounds', rnd)

        # Put down final weight fraction if etch is on
        if self.th.etch_on:
            frac = context.pruner.weights_fraction
            self.model.agent.take_notes(
                'Final weight fraction: {:.2f}%'.format(frac))
            self.model.agent.put_down_criterion('Weight Fraction', frac)

        # Evaluate the best model if necessary
        ds_dict = OrderedDict()
        if hub.evaluate_train_set: ds_dict['Train'] = self.training_set
        if hub.evaluate_val_set: ds_dict['Val'] = self.validation_set
        if hub.evaluate_test_set: ds_dict['Test'] = self.test_set
        if len(ds_dict) > 0:
            # Load the best model
            if hub.save_model:
                flag, _, _ = self.model.agent.load()
                assert flag
            # Evaluate the specified data sets
            for name, data_set in ds_dict.items():
                if not isinstance(data_set, TFRData):
                    raise TypeError('!! {} set is not a TFRData'.format(name))
                # TODO
                value = self.model.evaluate_model(
                    data_set, batch_size=hub.eval_batch_size)
                title = '{} {}'.format(name,
                                       self.metrics_manager.eval_slot.name)
                self.model.agent.put_down_criterion(title, value)
                self.model.agent.take_notes('{}: {}'.format(
                    title, hub.decimal_str(value, hub.val_decimals)))

        # Save model here if necessary
        if self._save_model_at_training_end:
            assert len(ds_dict) == 0
            self._save_model()

        return rnd
Example #6
File: agent.py  Project: rscv5/tframe
  def show_notes(self):
    console.section('Notes')
    console.write_line(self._note.content)
Example #7
    def train(self,
              agent,
              episodes=500,
              print_cycle=0,
              snapshot_cycle=0,
              match_cycle=0,
              rounds=100,
              rate_thresh=1.0,
              shadow=None,
              save_cycle=100,
              snapshot_function=None):
        # Validate agent
        if not isinstance(agent, FMDPAgent):
            raise TypeError('Agent should be an FMDP agent')

        # Check settings. TODO: this code should be reused
        if snapshot_function is not None:
            if not callable(snapshot_function):
                raise ValueError('snapshot_function must be callable')
            self._snapshot_function = snapshot_function

        print_cycle = FLAGS.print_cycle if FLAGS.print_cycle >= 0 else print_cycle
        snapshot_cycle = (FLAGS.snapshot_cycle
                          if FLAGS.snapshot_cycle >= 0 else snapshot_cycle)
        match_cycle = FLAGS.match_cycle if FLAGS.match_cycle >= 0 else match_cycle

        # Show configurations
        console.show_status('Configurations:')
        console.supplement('episodes: {}'.format(episodes))

        # Do some preparation
        if self._session is None:
            self.launch_model()

        assert isinstance(self._graph, tf.Graph)
        with self._graph.as_default():
            if self._merged_summary is None:
                self._merged_summary = tf.summary.merge_all()

        # Set opponent
        if match_cycle > 0:
            if shadow is None:
                self._opponent = FMDRandomPlayer()
                self._opponent.player_name = 'Random Player'
            elif isinstance(shadow, TDPlayer):
                self._opponent = shadow
                self._opponent.player_name = 'Shadow_{}'.format(
                    self._opponent.counter)
            else:
                raise TypeError('Opponent should be an instance of TDPlayer')

        # Begin training iteration
        assert isinstance(agent, FMDPAgent)
        console.section('Begin episodes')
        for epi in range(1, episodes + 1):
            # Initialize variable
            agent.restart()
            if hasattr(agent, 'default_first_move'):
                agent.default_first_move()
            # Record episode start time
            start_time = time.time()
            steps = 0

            state = agent.state
            summary = None
            # Begin current episode
            while not agent.terminated:
                # Make a move
                next_value = self.next_step(agent)
                steps += 1
                # Update model
                state = np.reshape(state, (1, ) + state.shape)
                next_value = np.reshape(np.array(next_value), (1, 1))
                feed_dict = {
                    self.input_[0]: state,
                    self._next_value: next_value
                }
                feed_dict.update(self._get_status_feed_dict(is_training=True))
                assert isinstance(self._session, tf.Session)
                summary, _ = self._session.run(
                    [self._merged_summary, self._update_op], feed_dict)

                state = agent.state

            # End of current episode
            self.counter += 1

            assert isinstance(self._summary_writer, tf.summary.FileWriter)
            self._summary_writer.add_summary(summary, self.counter)

            if print_cycle > 0 and np.mod(self.counter, print_cycle) == 0:
                self._print_progress(epi, start_time, steps, total=episodes)
            if snapshot_cycle > 0 and np.mod(self.counter,
                                             snapshot_cycle) == 0:
                self._snapshot(epi / episodes)
            if match_cycle > 0 and np.mod(self.counter, match_cycle) == 0:
                self._training_match(agent, rounds, epi / episodes,
                                     rate_thresh)
            if np.mod(self.counter, save_cycle) == 0:
                self._save(self.counter)

        # End training
        console.clear_line()
        self._summary_writer.flush()
        self.shutdown()
Example #8
File: model.py  Project: zkmartin/tframe
    def train(self,
              epoch=1,
              batch_size=128,
              training_set=None,
              validation_set=None,
              print_cycle=0,
              snapshot_cycle=0,
              snapshot_function=None,
              probe=None,
              **kwargs):
        # Check data
        if training_set is not None:
            self._training_set = training_set
        if validation_set is not None:
            self._validation_set = validation_set
        if self._training_set is None:
            raise ValueError('!! Data for training not found')
        elif not isinstance(self._training_set, TFData):
            raise TypeError(
                '!! Data for training must be an instance of TFData')
        if probe is not None and not callable(probe):
            raise TypeError('!! Probe must be callable')

        if snapshot_function is not None:
            if not callable(snapshot_function):
                raise ValueError('!! snapshot_function must be callable')
            self._snapshot_function = snapshot_function

        self._init_smart_train(validation_set)

        epoch_tol = FLAGS.epoch_tol

        # Get epoch and batch size
        epoch = FLAGS.epoch if FLAGS.epoch > 0 else epoch
        batch_size = FLAGS.batch_size if FLAGS.batch_size > 0 else batch_size
        assert isinstance(self._training_set, TFData)
        self._training_set.set_batch_size(batch_size)

        # Get print and snapshot cycles
        print_cycle = FLAGS.print_cycle if FLAGS.print_cycle >= 0 else print_cycle
        snapshot_cycle = (FLAGS.snapshot_cycle
                          if FLAGS.snapshot_cycle >= 0 else snapshot_cycle)

        # Run pre-train method
        self._pretrain(**kwargs)

        # Show configurations
        console.show_status('Configurations:')
        console.supplement('Training set feature shape: {}'.format(
            self._training_set.features.shape))
        console.supplement('epochs: {}'.format(epoch))
        console.supplement('batch size: {}'.format(batch_size))

        # Do some preparation
        if self._session is None:
            self.launch_model()
        if self._merged_summary is None:
            self._merged_summary = tf.summary.merge_all()

        # Begin iteration
        with self._session.as_default():
            for epc in range(epoch):
                console.section('Epoch {}'.format(epc + 1))
                # Add a new list to metric log if smart_train is on
                if self._train_status['metric_on']: self._metric_log.append([])
                # Record epoch start time
                start_time = time.time()
                while True:
                    # Get data batch
                    data_batch, end_epoch_flag = self._training_set.next_batch(
                        shuffle=FLAGS.shuffle and FLAGS.train)
                    # Increase counter, counter may be used in _update_model
                    self._counter += 1
                    # Update model
                    loss_dict = self._update_model(data_batch, **kwargs)
                    # Print status
                    if print_cycle > 0 and np.mod(self._counter - 1,
                                                  print_cycle) == 0:
                        loss_dict, new_record = self._update_loss_dict(
                            loss_dict, probe)
                        self._print_progress(epc,
                                             start_time,
                                             loss_dict,
                                             data_batch=data_batch)
                        if new_record:
                            self._last_epoch = epc
                            if FLAGS.save_best and epc + 1 >= FLAGS.dont_save_until:
                                if FLAGS.save_model:
                                    self._save(self._counter)
                                    self._inter_cut('[New Record] Model saved')

                    # Snapshot
                    if (FLAGS.snapshot and snapshot_cycle > 0 and np.mod(
                            self._counter - 1, snapshot_cycle) == 0):
                        self._snapshot()
                    # Check flag
                    if end_epoch_flag:
                        if FLAGS.progress_bar: console.clear_line()
                        console.show_status('End of epoch. Elapsed time is '
                                            '{:.1f} secs'.format(time.time() -
                                                                 start_time))
                        break

                # End of epoch
                break_flag = False
                since_last = epc - self._last_epoch
                if since_last == 0: self._train_status['bad_apples'] = 0
                else: self._train_status['bad_apples'] += 1
                save_flag = (self._apply_smart_train()
                             if FLAGS.smart_train else True)
                if self._train_status['metric_on']:
                    best_metric = self._session.run(self._best_metric)
                    console.supplement(
                        '[Best {:.3f}] {} epochs since the last record '
                        'appeared.'.format(best_metric, since_last))

                if not FLAGS.save_best and save_flag and FLAGS.save_model:
                    self._save(self._counter)
                    console.show_status('Model saved')
                elif since_last >= epoch_tol:
                    break_flag = True

                # Early stop if break flag is true
                if break_flag: break

        # End training
        if FLAGS.progress_bar: console.clear_line()

        # Write HP-tuning metric
        if FLAGS.hpt:
            summary = self._session.run(self._best_metric_sum)
            self._summary_writer.add_summary(summary, self._counter)

        if FLAGS.summary or FLAGS.hpt: self._summary_writer.flush()
Example #9
    return {1: 'linear', 2: 'quadratic', 3: 'cubic',
            4: 'quartic', 5: 'quintic', 6: 'sixtic',
            7: 'septic'}[self.order]


  @single_input
  def _link(self, input_, **kwargs):
    assert isinstance(input_, tf.Tensor)
    if self.coefs is not None: tf.get_variable_scope().reuse_variables()
    # Get input dimension
    D = input_.get_shape().as_list()[1]
    self.neuron_scale = [D] * self.order
    # Get variable
    self.coefs = tf.get_variable('coefs', shape=(D,) * self.order)
    # Calculate output
    result = self.coefs
    for dim in range(self.order - 1, -1, -1):
      shape = [-1, D] + [1] * dim
      result = tf.reshape(input_, shape=shape) * result
      name = ('d{}'.format(dim) if dim > 0
              else '{}_output'.format(self.poly_name))
      result = tf.reduce_sum(result, axis=1, name=name, keep_dims=(dim == 0))

    return result


if __name__ == '__main__':
  from tframe import console
  console.section('homogeneous.py test')
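
For reference, the reshape/broadcast/reduce_sum loop in _link above contracts one input axis per step, so for order 2 it evaluates the quadratic form x^T C x. A minimal NumPy sketch of that order-2 case (illustrative only, not tframe code):

import numpy as np

# Order-2 homogeneous polynomial via the same broadcast-and-sum pattern
# used in _link: start from the coefficient tensor and contract one
# input axis per iteration.
N, D, order = 4, 3, 2
x = np.random.randn(N, D)
coefs = np.random.randn(D, D)                   # shape (D,) * order

result = coefs
for dim in range(order - 1, -1, -1):
    shape = [-1, D] + [1] * dim                 # (N, D, 1), then (N, D)
    result = x.reshape(shape) * result          # broadcast multiply
    result = result.sum(axis=1, keepdims=(dim == 0))

# For order 2 this equals the quadratic form sum_ij x_i * C_ij * x_j
expected = np.einsum('ni,ij,nj->n', x, coefs, x)[:, None]
print(np.allclose(result, expected))            # True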