示例#1
0
    def run(self, strategy='grid', rehearsal=False, **kwargs):
        """Iterate over hyper-parameter combinations and run each of them.

        Kept for compatibility with the older tframe script_helper interface;
        expected to be deprecated in the future.

        :param strategy: fallback search strategy, used only when
            `self.configs` has no 'strategy' entry
        :param rehearsal: if True, combinations are displayed but never run
        :param kwargs: forwarded to `self.configure`
        """
        console.section('Script Information')
        # Pot configurations are displayed before `configure` is applied
        self._show_dict('Pot Configurations', self.configs)
        # Hyper-parameter info is shown once the scroll has been set
        self.configure(**kwargs)
        # Automatic settings, e.g. greater_is_better based on the criterion
        self._auto_config()
        chosen_strategy = self.configs.get('strategy', strategy)
        self.pot.set_scroll(chosen_strategy, **self.configs)
        self._show_dict('Common Settings', self.common_parameters)

        for index, combination in enumerate(self.pot.scroll.combinations()):
            console.show_info('Hyper-parameters:')
            for name, value in combination.items():
                console.supplement('{}: {}'.format(name, value), level=2)
            # A rehearsal stops after displaying the combination
            if not rehearsal:
                console.split()
                if self.pot.logging_is_needed:
                    self._export_log()
                self._run_process(combination, index)
示例#2
0
文件: net.py 项目: winkywow/tframe
    def _get_extra_loss(self):
        """Sum every extra loss term into a single tensor.

        The list taken from `context.loss_tensor_list` is extended in place
        with the customized losses, the collected regularization losses and,
        when `hub.global_l2_penalty` is positive, a global L2 term.

        :return: a tensor named 'extra_loss', or None if there is no term
        """
        extra_losses = context.loss_tensor_list
        assert isinstance(extra_losses, list)

        # (1) Customized losses
        customized = self._get_customized_loss()
        if customized:
            extra_losses.extend(customized)

        # (2) Regularization losses registered in the graph collection
        extra_losses += tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        # (2-A) Global L2 loss over all trainable variables
        if hub.global_l2_penalty > 0:
            l2_terms = [tf.nn.l2_loss(v) for v in self.var_list if v.trainable]
            extra_losses.append(hub.global_l2_penalty * tf.add_n(l2_terms))

        # Add up all extra losses
        total = tf.add_n(extra_losses, 'extra_loss') if extra_losses else None

        # Optionally list the terms (usually for debugging)
        if hub.show_extra_loss_info and extra_losses:
            console.show_info('Extra losses:')
            for term in extra_losses:
                assert isinstance(term, tf.Tensor)
                console.supplement(term.name, level=2)
            console.split()
        return total
示例#3
0
def on_key_press(viewer, event):
    """Dispatch a key-press event to the matching summary-viewer action.

    Key bindings handled here:
      quoteleft  print the active config flags
      h / k      move the header cursor by -1
      l / j      move the header cursor by +1
      n / p      move to the next / previous criteria group
      space      show the content of the selected note
      Return     trigger the label-detail click handler
      s          ask for a file name and save the selected note

    In debug mode the key is additionally forwarded to
    `on_key_press_debug` before the bindings above are evaluated.

    :param viewer: a `centre.SummaryViewer`
    :param event: the `tk.Event` describing the key press
    """
    # Sanity check
    assert isinstance(event, tk.Event)
    assert isinstance(viewer, centre.SummaryViewer)

    key_symbol = getattr(event, 'keysym')
    if viewer.in_debug_mode:
        on_key_press_debug(viewer, key_symbol)

    if key_symbol == 'quoteleft':
        console.show_status('Active flags:', symbol='::')
        for k, v in viewer.config_panel.active_config_dict.items():
            console.supplement('{}: {}'.format(k, v), level=2)
    elif key_symbol in ('h', 'k'):
        viewer.header.move_cursor(-1)
    elif key_symbol in ('l', 'j'):
        viewer.header.move_cursor(1)
    elif key_symbol == 'n':
        viewer.criteria_panel.move_between_groups(1)
    elif key_symbol == 'p':
        viewer.criteria_panel.move_between_groups(-1)
    elif key_symbol == 'space':
        viewer.header.show_selected_note_content()
    elif key_symbol == 'Return':
        viewer.header.on_label_detail_click()
    elif key_symbol == 's':
        file_name = tk.filedialog.asksaveasfilename(
            filetypes=[('Note file', '.note')],
            initialfile='untitled',
            defaultextension='.note')
        # A cancelled dialog yields an empty value rather than None, so test
        # truthiness; otherwise the note would be saved under an empty name
        if file_name:
            viewer.header.save_selected_note(file_name)
示例#4
0
    def evaluate_model(self, data, batch_size=None, dynamic=False, **kwargs):
        """Evaluate `self.eval_metric` on `data` and print it on the terminal.

        Unlike `self.evaluate`, only the eval metric is computed here.

        :param data: data set to evaluate on; its `name` appears in the status
        :param batch_size: passed on to `self.validate_model`
        :param dynamic: if True, delegate to DynamicEvaluator and return None
        :param kwargs: 'val_set' and 'delay' are read in dynamic mode
        :return: the metric value, or None in dynamic mode
        :raises AssertionError: if the eval metric is not activated
        """
        if not self.eval_metric.activated:
            raise AssertionError('!! Metric not defined')

        if dynamic:
            from tframe.trainers.eval_tools.dynamic_eval import DynamicEvaluator
            val_set = kwargs.get('val_set', None)
            delay = kwargs.get('delay', None)
            DynamicEvaluator.dynamic_evaluate(self, data, val_set, delay)
            return

        # When hub.val_progress_bar is True, this message is shown by the
        #   model.evaluate method instead
        if not hub.val_progress_bar:
            console.show_status('Evaluating on {} ...'.format(data.name))
        # The val_progress_bar option temporarily doubles as verbosity switch
        results = self.validate_model(
            data, batch_size, allow_sum=False, verbose=hub.val_progress_bar)
        result = results[self.eval_metric]
        value_str = hub.decimal_str(result, hub.val_decimals)
        console.supplement(
            '{} = {}'.format(self.eval_metric.symbol, value_str))

        return result
示例#5
0
 def _show_trend(self):
     tendency = ''
     if len(self._trend) > 0:
         for i, ratio in enumerate(self._trend):
             if i > 0: tendency += ', '
             tendency += '[{}]{:.1f}%'.format(i + 1, ratio)
     if tendency != '': console.supplement(tendency, level=2)
示例#6
0
 def compete(self, agent, rounds, opponent, title='Competition'):
     """Let `agent` run a competition between this player and `opponent`.

     Every report returned by `agent.compete` is printed.

     :return: the rate returned by `agent.compete`
     """
     banner = '[{}]'.format(title)
     console.show_status(banner)
     assert isinstance(agent, FMDPAgent)
     players = [self, opponent]
     rate, reports = agent.compete(players, rounds)
     for line in reports:
         console.supplement(line)
     return rate
示例#7
0
文件: net.py 项目: ssh352/tframe
  def _get_extra_loss(self):
    """Sum every extra loss term into a single tensor.

    The list taken from `context.loss_tensor_list` is extended in place with
    the customized losses and the collected regularization losses.

    :return: a tensor named 'extra_loss', or None if there is no term
    """
    extra_losses = context.loss_tensor_list
    assert isinstance(extra_losses, list)

    # (1) Customized losses
    customized = self._get_customized_loss()
    if customized:
      extra_losses.extend(customized)

    # (2) Regularizer losses registered in the graph collection
    extra_losses.extend(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    # Add up all extra losses
    total = tf.add_n(extra_losses, 'extra_loss') if extra_losses else None

    # Optionally list the terms (usually for debugging)
    if hub.show_extra_loss_info and extra_losses:
      console.show_info('Extra losses:')
      for term in extra_losses:
        assert isinstance(term, tf.Tensor)
        console.supplement(term.name, level=2)
      console.split()
    return total
示例#8
0
 def _show_configurations(self):
     """Print every config string and record it in the agent's notes."""
     header = 'Configurations:'
     console.show_status(header)
     self.model.agent.take_notes(header, date_time=False)
     for entry in self.th.config_strings:
         console.supplement(entry)
         note = '.. {}'.format(entry)
         self.model.agent.take_notes(note, date_time=False)
示例#9
0
def verify(vns, wn, system, test_set):
  """Report how well the Wiener model and each VN reproduce the system.

  Prints the linear coefficients of the system and of every VN, then the
  error ratio (in percent) of each model on the first signal of `test_set`.
  The Wiener part runs only when WIENER_ON is set; results are plotted only
  when PLOT is set.

  :param vns: dict mapping a strength value to a VN model
  :param wn: callable Wiener model
  :param system: object exposing `kernels.linear_coefs`
  :param test_set: provides `signls[0]` and `responses[0]`
  """
  # Linear coefficients
  console.show_status(
    'System linear coefs = {}'.format(system.kernels.linear_coefs))
  console.show_status('VN linear coefs:')
  for strength, vn in vns.items():
    console.supplement('vn_{:.2f}: {}'.format(strength, vn.nn.linear_coefs))

  # Reference input and output
  signal = test_set.signls[0]
  system_output = test_set.responses[0]

  # Wiener error ratio
  wiener_output = wiener_delta = None
  if WIENER_ON:
    wiener_output = wn(signal)
    wiener_delta = system_output - wiener_output
    wiener_ratio = wiener_delta.norm / system_output.norm * 100
    console.show_status('Wiener err ratio = {:.2f} %'.format(wiener_ratio))

  # VN error ratios; the entry with strength 0 always replaces the current
  # best, any other entry only when its ratio is lower
  console.show_status('VN error ratio:')
  best_str, best_ratio = None, 9999
  best_delta = best_output = None
  for strength, vn in vns.items():
    output = vn(signal)
    delta = system_output - output
    ratio = delta.norm / system_output.norm * 100
    is_new_best = strength == 0 or ratio < best_ratio
    if is_new_best:
      best_str, best_ratio = strength, ratio
      best_delta, best_output = delta, output
    console.supplement(
      'VN_{:.2f} err ratio = {:.2f} %'.format(strength, ratio))

  # Plot
  if PLOT:
    plot(system_output, wiener_output, wiener_delta,
         best_output, best_delta, best_str)
示例#10
0
    def _on_key_press(self, event):
        assert isinstance(event, tk.Event)

        flag = False
        if event.keysym == 'Escape':
            self.form.quit()
        elif event.keysym == 'j':
            flag = self._move_cursor(1)
        elif event.keysym == 'k':
            flag = self._move_cursor(-1)
        elif event.keysym == 'quoteleft':
            console.show_status('Widgets sizes:')
            for k in self.__dict__.keys():
                item = self.__dict__[k]
                if isinstance(item, tk.Widget) or k == 'form':
                    str = '[{}] {}: {}x{}'.format(item.__class__, k,
                                                  item.winfo_height(),
                                                  item.winfo_width())
                    console.supplement(str)
        elif event.keysym == 'Tab':
            console.show_status('Data:')
            for k in self.dataset._data.keys():
                console.supplement('{}: {}'.format(
                    k, self.dataset._data[k].shape))
        elif event.keysym == 'space':
            self._resize()
        else:
            # console.show_status(event.keysym)
            pass

        # If needed, refresh image viewer
        if flag:
            self.refresh()
示例#11
0
 def _on_detail_btn_click(self):
   """Print this criterion's values in sorted order, or a hint if empty."""
   if len(self.value_list) > 0:
     console.show_status('{}:'.format(self.name), '::')
     for value in np.sort(self.value_list):
       console.supplement('{}'.format(value), level=2)
   else:
     console.show_status('No notes with these criteria found under the '
                         'corresponding configures', '::')
示例#12
0
def homogeneous_check(model, order, input_, output):
  """Print how far `model` deviates from homogeneity of the given order.

  The input is scaled by alpha = 2 and the model response is compared with
  `alpha ** order` times `output`; the relative error is shown in percent.
  """
  console.show_status('Checking homogeneous system')
  alpha = 2
  console.supplement('Alpha = {}'.format(alpha))
  expected = (alpha ** order) * output
  scaled_response = model(alpha * input_)
  error = scaled_response - expected
  error_pct = error.norm / expected.norm * 100
  console.supplement('Error Ratio = {:.2f} %'.format(error_pct))
示例#13
0
 def evaluate_model(self, data, batch_size=None, **kwargs):
     """Validate the model on `data` and print the value of `self.metric`.

     :param data: data set to evaluate; its `name` appears in the status
     :param batch_size: passed on to `self.validate_model`
     :raises AssertionError: if `self.metric` is not activated
     """
     if not self.metric.activated:
         raise AssertionError('!! Metric not defined')
     console.show_status('Evaluating {} ...'.format(data.name))
     results = self.validate_model(data, batch_size, allow_sum=False)
     value = results[self.metric]
     console.supplement('{} = {:.3f}'.format(self.metric.symbol, value))
示例#14
0
 def _validate(self, data_set, batch_size=1, num_steps=1000):
     """Validate the model on `data_set` and print every resulting value.

     Sequence detail is requested when `th.val_info_splits` is positive.
     """
     want_seq_detail = th.val_info_splits > 0
     results = self.model.validate_model(
         data_set, batch_size=batch_size, verbose=True,
         num_steps=num_steps, seq_detail=want_seq_detail)
     for name, val in results.items():
         shown = th.decimal_str(val, th.val_decimals)
         console.supplement('{} = {}'.format(name, shown))
示例#15
0
 def set_hyper_parameters(self, eta, lambd):
     """Store eta and lambda and push both values into the graph.

     :param eta: positive float
     :param lambd: positive float
     :raises AssertionError: if either value is not a positive float
     """
     candidates = (eta, lambd)
     assert all(isinstance(value, float) for value in candidates)
     assert all(value > 0 for value in candidates)
     self._eta = eta
     self._lambda = lambd
     # Run both assignment ops with the new values fed in
     assign_ops = [self._set_eta_op, self._set_lambda_op]
     feeds = {self._eta_placeholder: eta, self._lambda_placeholder: lambd}
     self._model.session.run(assign_ops, feed_dict=feeds)
     # Show status
     console.show_info('Hyper parameters for dynamic evaluation updated:')
     console.supplement('eta = {}, lambda = {}'.format(eta, lambd))
示例#16
0
def load_balanced_data(data_dir, train_size=20, validation_size=2,
                       init_f=None, round_len_f=None):
  """Load labelled GPAT data and split off a training and a validation set.

  :param data_dir: directory handed to `GPATBigData.load`
  :param train_size: size passed to `pop_subset` for the training set
  :param validation_size: size passed to `pop_subset` for the validation set
  :param init_f: assigned to the `init_f` attribute of both subsets
  :param round_len_f: assigned to the `round_len_f` attribute of both subsets
  :return: (training set merged into a signal set, validation set)
  """
  # Load data from data_dir
  data_set = GPATBigData.load(
    data_dir, csv_path=core.v_train_csv_path,
    lb_sheet_path=core.label_sheet_path)
  assert isinstance(data_set, GPATBigData)
  assert data_set.with_labels

  # Pop subsets; the training set is popped first
  train_set = data_set.pop_subset(
    train_size, length_prone='long', name='train_set')
  val_set = data_set.pop_subset(
    validation_size, length_prone='short', name='validation_set')
  for subset in (train_set, val_set):
    subset.init_f = init_f
    subset.round_len_f = round_len_f

  # Show status
  headers = (('Train set (size={})', train_set),
             ('Validation set (size={})', val_set))
  for template, subset in headers:
    console.show_status(template.format(subset.size))
    console.supplement('min_len = {}'.format(subset.min_length))
    console.supplement('max_len = {}'.format(subset.max_length))

  return train_set.merge_to_signal_set(), val_set
示例#17
0
 def set_search_space(self, hyper_params):
     """Register the given hyper-parameters under their names.

     An empty list is a no-op. When `self.enable_hp_types` is set, every
     entry is first replaced by the result of its `seek_myself()`.

     :param hyper_params: list of HyperParameter instances
     :raises AssertionError: if the argument is not a list, or an entry is
         not a HyperParameter of one of `self.valid_HP_types`
     """
     assert isinstance(hyper_params, list)
     if not hyper_params:
         return
     # Find appropriate HP type if different types are allowed
     if self.enable_hp_types:
         hyper_params = [candidate.seek_myself() for candidate in hyper_params]
     # Show and register each hyper-parameter
     console.show_info('Hyper Parameters -> {}'.format(self.name))
     for param in hyper_params:
         assert isinstance(param, HyperParameter)
         assert isinstance(param, self.valid_HP_types)
         description = '{}: {}'.format(param.name, param.option_str)
         console.supplement(description, level=2)
         self.hyper_params[param.name] = param
示例#18
0
    def build(self, **kwargs):
        """Build the model, report its structure and register metric slots.

        :param kwargs: forwarded to `self._build`. 'optimizer' defaults to
            `hub.get_optimizer()` when absent; 'batch_metric' (a name or a
            list/tuple of names) adds slots to the update group;
            'eval_metric' (a name) is registered as the eval slot.
        """
        # Smooth out flags before important actions
        hub.smooth_out_conflicts()

        # Initialize pruner if any related flag is set
        pruner_flags = (hub.prune_on, hub.weights_mask_on, hub.etch_on,
                        hub.force_to_use_pruner)
        if any(pruner_flags):
            # Imported here to prevent a circular import (temporarily)
            from tframe.operators.prune.pruner import Pruner
            tfr.context.pruner = Pruner(self)

        # Without an explicit optimizer, fall back to hub.get_optimizer();
        #   this requires th.optimizer and th.learning_rate to be provided
        if 'optimizer' not in kwargs:
            kwargs['optimizer'] = hub.get_optimizer()

        # Delegate to the successor's _build, then finish initialization
        self._build(**kwargs)
        self._init_monitor()
        self._built = True

        # Show build info and take notes
        console.show_status('Model built successfully:')
        self.agent.take_notes('Model built successfully')
        self.agent.take_notes('Structure:', date_time=False)
        # The description may be a single value or a sequence of lines
        description = self.description
        if isinstance(description, (tuple, list)):
            lines = description
        else:
            lines = [description]
        for line in lines:
            assert isinstance(line, str)
            console.supplement(line)
            self.agent.take_notes(line, date_time=False)

        # Add metric slots to the update group
        batch_metric = kwargs.get('batch_metric', [])
        if batch_metric:
            if not isinstance(batch_metric, (tuple, list)):
                batch_metric = [batch_metric]
            for metric_name in batch_metric:
                assert isinstance(metric_name, str)
                slot = self.metrics_manager.get_slot_by_name(metric_name)
                self._update_group.add(slot)

        # Register eval_metric if provided
        eval_metric = kwargs.get('eval_metric', None)
        if eval_metric is not None:
            assert isinstance(eval_metric, str)
            self.metrics_manager.register_eval_slot(eval_metric)
示例#19
0
    def _advanced_strategy(self, rnd):
        """This method will be called after Metric.end_round method"""
        if not self.th.smart_train: return
        if self.key_metric.trend_is_promising and self.th.bad_apples > 0:
            self.th.bad_apples -= 1
        if self.key_metric.get_idle_rounds(rnd) > 0:
            self.th.bad_apples += 1
            if self.th.bad_apples > self.th.max_bad_apples:
                # Decay learning rate and reset bad apples
                # self.model.agent.load()
                self.lr = self.model.tune_lr(coef=self.th.lr_decay)
                self.th.bad_apples = 0
        else:
            # Reset bad apples when new record appears
            self.th.bad_apples = 0

        # Show bad apple count
        console.supplement('{} bad apples found (lr = {:.2e})'.format(
            self.th.bad_apples, self.lr))
示例#20
0
 def print_latest_stats(self, prompt='[Validate]', decimals=3):
     """Print the latest statistics of every data set, slot by slot.

     :param prompt: prompt string shown with each data set header
     :param decimals: number of decimals used to format each value
     """
     assert isinstance(prompt, str)
     stats_dict = self.latest_stats_dict
     assert isinstance(stats_dict, OrderedDict)
     # e.g. '{} = {:.3f}' for decimals = 3
     template = '{{}} = {{:.{}f}}'.format(decimals)
     for data_set, scalar_dict in stats_dict.items():
         assert isinstance(scalar_dict, OrderedDict)
         console.show_status('On {}'.format(data_set.name), prompt)
         for slot, value in scalar_dict.items():
             info = template.format(slot.symbol, value)
             # Append the note as suffix when one exists for this entry;
             # only then is the early-stop slot marked with '(*)'
             note_key = (data_set, slot)
             if note_key in self.note.keys():
                 info = '{} {}'.format(info, self.note[note_key])
                 if slot is self.early_stop_slot:
                     info = '(*) ' + info
             console.supplement(info)
     # Recover progress bar if necessary
     self.trainer.recover_progress()
示例#21
0
    def run(self, times=1, save=False, mark='', rehearsal=False):
        """Launch the module once per unique hyper-parameter configuration.

        :param times: number of runs; overridden by `self._sys_runs` when
            that is not None
        :param save: if True, 'save_model' is switched on and every task
            gets a script suffix built from `mark` and a running counter;
            overridden by `self._add_script_suffix` when that is not None
        :param mark: prefix used inside the script suffix
        :param rehearsal: if True, tasks are only displayed and the method
            returns after the first run
        """
        if self._sys_runs is not None:
            times = checker.check_positive_integer(self._sys_runs)
            console.show_status('Run # set to {}'.format(times))
        # Set the corresponding flags if save
        if save:
            self.common_parameters['save_model'] = True
        self._show_parameters()

        # The counter keeps growing across runs and counts skipped
        # duplicates too
        counter = 0
        for run_number in range(1, times + 1):
            # Configurations already launched within this run
            launched = []
            for hyper_dict in self._hyper_parameter_dicts():
                counter += 1
                # Grant self._add_script_suffix the highest priority
                if self._add_script_suffix is not None:
                    save = self._add_script_suffix
                if save:
                    suffix = '_{}{}'.format(mark, counter)
                    self.common_parameters['script_suffix'] = suffix

                params = self._get_all_configs(hyper_dict)
                self._apply_constraints(params)
                params_list = self._get_config_strings(params)
                params_string = ' '.join(params_list)

                if params_string not in launched:
                    launched.append(params_string)
                    progress = '[Run {}/{}][{}]'.format(
                        run_number, times, len(launched))
                    console.show_status('Loading task ...', progress)
                    console.show_info('Hyper-parameters:')
                    for name, value in hyper_dict.items():
                        console.supplement('{}: {}'.format(name, value))
                    if not rehearsal:
                        call([self._python_cmd, self.module_name]
                             + params_list)
                        print()
            # End of the run
            if rehearsal:
                return
示例#22
0
    def end_round(self, rnd):
        """Close the current round's metric log and report its statistics.

        The list of metrics logged during the round is replaced by its mean,
        the trend against up to `self.memory` previous means is stored in
        `self._trend`, the mean record is updated, and a fresh log list is
        opened for the next round. A round without any logged metric is
        left untouched.

        :param rnd: round index, passed on to `self.get_idle_rounds`
        :return: True if the mean of this round became the new mean record
        """
        latest = self._metric_logs[-1]
        assert isinstance(latest, list)
        if not latest:
            return False

        # Replace the raw log of this round by its mean
        metric_mean = np.mean(self._metric_logs.pop())
        self._metric_logs.append(metric_mean)

        # trend = [re(rnd-1), re(rnd-2), ..., re(rnd-memory)] in percent
        # (metric can be negative, so hist_mean is not asserted >= 0; TODO X)
        depth = min(self.memory, len(self._metric_logs) - 1)
        past_means = [self._metric_logs[-(k + 2)] for k in range(depth)]
        self._trend = [
            (metric_mean - past) / past * 100 for past in past_means]

        # Update best mean metric
        new_record = False
        mean_record = self.mean_record
        if (self._mean_record.never_assigned
                or self.is_better_than(metric_mean, mean_record)):
            self.mean_record = mean_record = metric_mean
            new_record = True

        # Show metric mean status
        # TODO: console access should be somehow controlled
        token = 'min' if self.lower_is_better else 'max'
        summary = 'E[{}] = {:.3f}, {}(E[{}]) = {:.3f}'.format(
            self.symbol, metric_mean, token, self.symbol, mean_record)
        console.supplement(summary)
        self._show_trend()

        # Show record
        console.supplement(
            '[Best {:.3f}] {} rounds since last record appears.'.format(
                self.record, self.get_idle_rounds(rnd)))

        # Append log container for new round
        self._metric_logs.append([])

        return new_record
示例#23
0
 def build(self, optimizer=None, **kwargs):
     """Build the model, then print its description and take notes of it.

     :param optimizer: forwarded to `self._build`
     :param kwargs: forwarded to `self._build`
     """
     # Smooth out flags before important actions
     hub.smooth_out_conflicts()
     self._build(optimizer=optimizer, **kwargs)
     self._init_monitor()
     self._built = True

     # The description may be a single value or a sequence of lines
     description = self.description
     if isinstance(description, (tuple, list)):
         lines = description
     else:
         lines = [description]

     # Show build info
     console.show_status('Model built successfully:')
     for line in lines:
         assert isinstance(line, str)
         console.supplement(line)

     # Take notes
     self.agent.take_notes('Model built successfully')
     self.agent.take_notes('Structure:', date_time=False)
     for line in lines:
         self.agent.take_notes(line, date_time=False)
示例#24
0
    def set_constraints(self, constraints):
        """Validate, display and store the given constraints.

        :param constraints: dict mapping a condition tuple to a dict of
            hyper-parameter settings; an empty dict is a no-op
        :raises AssertionError: on wrong types, or when a multi-value
            setting contains a choice that has not been registered in
            `self.hyper_params`
        """
        assert isinstance(constraints, dict)
        if not constraints:
            return
        console.show_info('Constraints')
        for cond, cons in constraints.items():
            assert isinstance(cond, tuple) and isinstance(cons, dict)
            # Every value of a multi-value setting must have been registered,
            # otherwise ambiguity arises while applying the constraint
            for key, value in cons.items():
                if not isinstance(value, (tuple, set, list)):
                    continue
                for choice in value:
                    if choice in self.hyper_params[key].choices:
                        continue
                    raise AssertionError(
                        '!! Value {} for {} when {} should be registered first.'
                        .format(choice, key, cond))

            # Consider skipping constraints whose condition can never be
            #   satisfied

            console.supplement('{}: {}'.format(cond, cons), level=2)
            self.constraints[cond] = cons
示例#25
0
 def _dynamic_eval(self, data_set, lr, lambd,
                   prompt='[Dynamic Evaluation]'):
     """Run one dynamic evaluation on `data_set` and return the metric.

     The optimizer parameters are reset and the given hyper-parameters are
     set before the model is evaluated with batch size 1.

     :param data_set: the DataSet to evaluate on
     :param lr: first value handed to `optimizer.set_hyper_parameters`
     :param lambd: second value handed to `optimizer.set_hyper_parameters`
     :param prompt: prompt shown before the evaluation starts
     :return: result of applying the quantity's summary method to the output
     """
     assert isinstance(data_set, DataSet)
     console.show_status('', prompt)
     # Reset parameters, then set HP to optimizer
     optimizer = self.optimizer
     optimizer.reset_parameters()
     optimizer.set_hyper_parameters(lr, lambd)
     # Do dynamic evaluation
     outputs = self.model.evaluate(
         self._dynamic_fetches, data_set, batch_size=1, verbose=True,
         num_steps=th.de_num_steps)
     output = outputs[0]
     assert isinstance(self._quantity, Quantity)
     metric = self._quantity.apply_np_summ_method(output)
     metric_str = th.decimal_str(metric, th.val_decimals)
     console.supplement(
         'Dynamic {} = {}'.format(self._quantity.name, metric_str))
     return metric
示例#26
0
文件: demo_tlp.py 项目: zkmartin/nls
def evaluate(u, y):
  """Print error statistics of the module-level `model` on input `u`.

  The first five samples of both the system output `y` and the model output
  are dropped. Mean, standard deviation and RMS of the error are printed
  together with their percentage of the RMS of the system output.

  :return: (system_output, model_output, delta)
  """
  system_output = y[5:]
  model_output = model(u)[5:]
  delta = system_output - model_output
  rms_truth = float(np.sqrt(np.mean(system_output * system_output)))

  def report(template, value):
    # Print the value and its percentage of the RMS of the truth
    console.supplement(template.format(value, value / rms_truth * 100))

  report('E[err] = {:.4f} ({:.3f}%)', delta.average)
  report('STD[err] = {:.4f} ({:.3f}%)', float(np.std(delta)))
  report('RMS[err] = {:.4f} ({:.3f}%)',
         float(np.sqrt(np.mean(delta * delta))))

  console.supplement('RMS[truth] = {:.4f}'.format(rms_truth))

  return system_output, model_output, delta
示例#27
0
文件: model.py 项目: zkmartin/tframe
    def _apply_smart_train(self):
        """Analyze the metric log at the end of an epoch.

        The raw metric list of the epoch is replaced by its mean, the best
        mean metric stored in the session is refreshed, the status is
        printed, and the learning rate may be tuned.

        :return: True if the mean metric of this epoch became the new best
        """
        memory = 4
        lr_decay = FLAGS.lr_decay

        # Replace the raw log of this epoch by its mean
        assert isinstance(self._metric_log[-1], list)
        metric_mean = np.mean(self._metric_log.pop())
        self._metric_log.append(metric_mean)

        # Relative change (in percent) against up to `memory` earlier means
        changes = []
        depth = min(memory, len(self._metric_log) - 1)
        for back in range(depth):
            past_mean = self._metric_log[-(back + 2)]
            assert past_mean > 0
            changes.append((metric_mean - past_mean) / past_mean * 100)

        # Refresh best mean metric; a negative stored value is replaced too
        save_flag = False
        best_mean_metric = self._session.run(self._best_mean_metric)
        if metric_mean < best_mean_metric or best_mean_metric < 0:
            save_flag = True
            self._session.run(tf.assign(self._best_mean_metric, metric_mean))
            best_mean_metric = metric_mean

        # Show status
        console.supplement(
            'E[metric] = {:.3f}, min(E[metric]) = {:.3f}'.format(
                metric_mean, best_mean_metric))
        tendency = ', '.join(
            '[{}]{:.1f}%'.format(rank, ratio)
            for rank, ratio in enumerate(changes, start=1))
        if tendency:
            console.supplement(tendency, level=2)

        # Tune learning rate TODO: smart train will only be applied here
        status = self._train_status
        if len(self._metric_log) > memory and status['metric_on']:
            if all(np.array(changes) < 0) and status['bad_apples'] > 0:
                status['bad_apples'] -= 1
            console.supplement(
                '{} bad apples found'.format(status['bad_apples']))
            if status['bad_apples'] > memory and FLAGS.smart_train:
                self._tune_lr(lr_decay)
                status['bad_apples'] = 0
                if not FLAGS.save_best:
                    console.show_status('save_best option has been turned on')
                FLAGS.save_best = True

        return save_flag
示例#28
0
    def _on_key_press(self, event):
        assert isinstance(event, tk.Event)
        big_step = max(self.data_set.size // 10, 1)

        flag = False
        if event.keysym == 'Escape':
            self.form.quit()
        elif event.keysym == 'j':
            flag = self._move_cursor(1)
        elif event.keysym == 'k':
            flag = self._move_cursor(-1)
        elif event.keysym == 'J':
            flag = self._move_cursor(big_step)
        elif event.keysym == 'K':
            flag = self._move_cursor(-big_step)
        elif event.keysym == 'l':
            flag = self._move_cursor(1, horizontal=True)
        elif event.keysym == 'h':
            flag = self._move_cursor(-1, horizontal=True)
        elif event.keysym == 'quoteleft':
            console.show_status('Widgets sizes:')
            for k in self.__dict__.keys():
                item = self.__dict__[k]
                if isinstance(item, tk.Widget) or k == 'form':
                    str = '[{}] {}: {}x{}'.format(item.__class__, k,
                                                  item.winfo_height(),
                                                  item.winfo_width())
                    console.supplement(str)
        elif event.keysym == 'Tab':
            if self.data_set is None: return
            assert isinstance(self.data_set, DataSet)
            data = self.data_set.data_dict
            data[pedia.features] = self.data_set.features
            if self.data_set.targets is not None:
                data[pedia.targets] = self.data_set.targets
            console.show_status('Data:', '::')
            for k, v in data.items():
                if not hasattr(v, 'shape'): continue
                console.supplement('{}: {}'.format(k, v.shape))
            for k, v in self.data_set.properties.items():
                console.supplement('{}: {}'.format(k, v))
        elif event.keysym == 'space':
            self._resize()
        else:
            # console.show_status(event.keysym)
            pass

        # If needed, refresh image viewer
        if flag: self.refresh()
示例#29
0
def load_wiener_hammerstein(filename,
                            validation_size=20000,
                            test_size=88000,
                            depth=None):
    """Load the Wiener-Hammerstein benchmark and split it into three sets.

    The test set is taken from the end of the signal, the validation set
    directly follows the training set, and the training set takes whatever
    remains at the start.

    :param filename: path handed to `DataSet.load`
    :param validation_size: number of samples in the validation set
    :param test_size: number of samples in the test set
    :param depth: passed as `memory_depth` to every created DataSet
    :return: (training_set, validation_set, test_set)
    :raises ValueError: if validation and test sizes exceed the total size
    """
    console.show_status('Loading Wiener-Hammerstein benchmark ...')

    # Load dataset and check input parameters
    dataset = DataSet.load(filename)
    assert isinstance(dataset, DataSet)
    u = dataset.signls[0]
    y = dataset.responses[0]
    total_size = u.size
    if validation_size + test_size > total_size:
        raise ValueError(
            '!! validation_size({}) + test_size({}) > total_size({})'.format(
                validation_size, test_size, total_size))

    def make_subset(region, name, **extra):
        # Cut the same region out of input and response
        return DataSet(u[region], y[region], memory_depth=depth, name=name,
                       **extra)

    # Separate data
    training_size = total_size - validation_size - test_size
    val_end = training_size + validation_size
    training_set = make_subset(
        slice(0, training_size), 'training set', cut=True)
    validation_set = make_subset(
        slice(training_size, val_end), 'validation set')
    test_set = make_subset(
        slice(total_size - test_size, total_size), 'test set')

    # Show status
    console.show_status('Data set loaded')
    sizes = (('Training set size: {}', training_set),
             ('Validation set size: {}', validation_set),
             ('Test set size: {}', test_set))
    for template, subset in sizes:
        console.supplement(template.format(subset.signls[0].size))

    return training_set, validation_set, test_set
示例#30
0
文件: whbm.py 项目: winkywow/tframe
  def evaluate(f, data_set, plot=False):
    """Evaluate the callable model `f` on `data_set` and print error stats.

    Mean, standard deviation and RMS of the simulation error are printed,
    each scaled by 1000 with an 'mV' label and as a percentage of the RMS of
    the targets. Optionally the power spectra of the system output and the
    model output are plotted.

    :param f: callable mapping the features of `data_set` to a model output
    :param data_set: a SignalSet with targets
    :param plot: if True, plot the power spectra after printing
    :raises AssertionError: if `f` is not callable
    :raises ValueError: if `data_set` has no targets
    """
    if not callable(f): raise AssertionError('!! Input f must be callable')
    checker.check_type(data_set, SignalSet)
    assert isinstance(data_set, SignalSet)
    if data_set.targets is None:
      raise ValueError('!! Responses not found in SignalSet')
    u, y = data_set.features, np.ravel(data_set.targets)
    assert isinstance(y, Signal)
    # Show status
    console.show_status('Evaluating {} ...'.format(data_set.name))
    # In evaluation, the sum of each metric is started at t = 1000 instead of
    #  t = 0 to eliminate the influence of transient errors at the beginning of
    #  the simulation
    start_at = 1000
    model_output = Signal(f(u), fs=y.fs)
    delta = y - model_output
    err = delta[start_at:]
    assert isinstance(err, Signal)
    # Percentage relative to the RMS of the whole target signal
    ratio = lambda val: 100.0 * val / y.rms

    # The mean value of the simulation error in time domain
    val = err.average
    console.supplement('E[err] = {:.4f}mV ({:.3f}%)'.format(
      val * 1000, ratio(val)))
    # The standard deviation of the error in time domain
    val = float(np.std(err))
    console.supplement('STD[err] = {:.4f}mV ({:.3f}%)'.format(
      val * 1000, ratio(val)))
    # The root mean square value of the error in time domain
    val = err.rms
    console.supplement('RMS[err] = {:.4f}mV ({:.3f}%)'.format(
      val * 1000, ratio(val)))

    # Plot
    if not plot: return
    from tframe.data.sequences.signals.figure import Figure, Subplot
    fig = Figure('Simulation Error')
    # Add ground truth
    prefix = 'System Output, $||y|| = {:.4f}$'.format(y.norm)
    fig.add(Subplot.PowerSpectrum(y, prefix=prefix))
    # Add model output; a raw string keeps the backslash in '\Delta' without
    # relying on an invalid escape sequence
    prefix = r'Model Output, RMS($\Delta$) = ${:.4f}mV$'.format(
      1000 * err.rms)
    fig.add(Subplot.PowerSpectrum(model_output, prefix=prefix, Error=delta))
    # Plot
    fig.plot()