コード例 #1
0
ファイル: net.py プロジェクト: winkywow/tframe
    def _get_extra_loss(self):
        """Gather all auxiliary loss tensors and add them up.

        The tensors are accumulated in-place in ``context.loss_tensor_list``
        in this order: customized losses, tensors found in the graph's
        REGULARIZATION_LOSSES collection and, when ``hub.global_l2_penalty``
        is positive, a global L2 penalty over the trainable variables in
        ``self.var_list``.

        :return: the ``tf.add_n`` sum named 'extra_loss', or None when the
                 list ends up empty
        """
        extra_losses = context.loss_tensor_list
        assert isinstance(extra_losses, list)

        # (1) Customized losses
        custom = self._get_customized_loss()
        if custom:
            extra_losses.extend(custom)

        # (2) Regularization losses registered in the graph
        extra_losses += tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)

        # (2-A) Global L2 penalty over trainable variables
        if hub.global_l2_penalty > 0:
            l2_terms = [
                tf.nn.l2_loss(var) for var in self.var_list if var.trainable]
            extra_losses.append(hub.global_l2_penalty * tf.add_n(l2_terms))

        # Sum everything up (None if there is nothing to sum)
        total = tf.add_n(extra_losses, 'extra_loss') if extra_losses else None

        # Optionally list the collected tensors (usually for debugging)
        if not (hub.show_extra_loss_info and extra_losses):
            return total
        console.show_info('Extra losses:')
        for tensor in extra_losses:
            assert isinstance(tensor, tf.Tensor)
            console.supplement(tensor.name, level=2)
        console.split()
        return total
コード例 #2
0
ファイル: net.py プロジェクト: ssh352/tframe
  def _get_extra_loss(self):
    """Collect the extra loss tensors and return their sum.

    Customized losses and the tensors in the graph's REGULARIZATION_LOSSES
    collection are appended in-place to ``context.loss_tensor_list``.

    :return: the ``tf.add_n`` sum named 'extra_loss', or None when no extra
             loss exists
    """
    extra_losses = context.loss_tensor_list
    assert isinstance(extra_losses, list)

    # (1) Customized losses
    custom = self._get_customized_loss()
    if custom:
      extra_losses.extend(custom)

    # (2) Regularizer losses
    extra_losses.extend(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))

    # Nothing collected: nothing to sum and nothing to show
    if not extra_losses:
      return None

    total = tf.add_n(extra_losses, 'extra_loss')

    # Show loss list (usually for debugging)
    if hub.show_extra_loss_info:
      console.show_info('Extra losses:')
      for tensor in extra_losses:
        assert isinstance(tensor, tf.Tensor)
        console.supplement(tensor.name, level=2)
      console.split()
    return total
コード例 #3
0
    def run(self, strategy='grid', rehearsal=False, **kwargs):
        """Walk through the scroll's hyper-parameter combinations and run them.

        This method is compatible with the old version of tframe
        script_helper, and should be deprecated in the future.

        :param strategy: scroll strategy used when ``self.configs`` has no
                         'strategy' entry
        :param rehearsal: if True, each combination is only displayed; no log
                          is exported and no process is run
        :param kwargs: forwarded to ``self.configure``
        """
        console.section('Script Information')
        self._show_dict('Pot Configurations', self.configs)
        # Hyper-parameter info will be shown when the scroll is set
        self.configure(**kwargs)
        # Auto configuring, e.g., set greater_is_better based on the given
        #   criterion
        self._auto_config()
        scroll_strategy = self.configs.get('strategy', strategy)
        self.pot.set_scroll(scroll_strategy, **self.configs)
        self._show_dict('Common Settings', self.common_parameters)

        # Begin iteration
        for index, combination in enumerate(self.pot.scroll.combinations()):
            console.show_info('Hyper-parameters:')
            for key, value in combination.items():
                console.supplement('{}: {}'.format(key, value), level=2)
            if not rehearsal:
                console.split()
                # Export log if necessary
                if self.pot.logging_is_needed:
                    self._export_log()
                self._run_process(combination, index)
コード例 #4
0
 def set_hyper_parameters(self, eta, lambd):
     """Store eta and lambda and push both values into the graph.

     :param eta: positive float
     :param lambd: positive float
     :raises AssertionError: if either value is not a float or is not
                             positive
     """
     assert isinstance(eta, float) and isinstance(lambd, float)
     assert 0 < eta and 0 < lambd
     self._eta = eta
     self._lambda = lambd
     # Run the two setter ops, feeding the new values via the placeholders
     session = self._model.session
     setter_ops = [self._set_eta_op, self._set_lambda_op]
     feed = {self._eta_placeholder: eta, self._lambda_placeholder: lambd}
     session.run(setter_ops, feed_dict=feed)
     # Report the new values
     console.show_info('Hyper parameters for dynamic evaluation updated:')
     summary = 'eta = {}, lambda = {}'.format(eta, lambd)
     console.supplement(summary)
コード例 #5
0
ファイル: scroll_base.py プロジェクト: winkywow/tframe
 def set_search_space(self, hyper_params):
     """Register the given hyper-parameters in ``self.hyper_params``.

     Does nothing for an empty list. When ``self.enable_hp_types`` is set,
     each entry is first replaced by the result of its ``seek_myself()``.

     :param hyper_params: list of HyperParameter instances, each of which
                          must also be an instance of ``self.valid_HP_types``
     """
     assert isinstance(hyper_params, list)
     if not hyper_params:
         return
     # Find appropriate HP type if different types are allowed
     if self.enable_hp_types:
         hyper_params = [param.seek_myself() for param in hyper_params]
     # Show and register each hyper-parameter
     console.show_info('Hyper Parameters -> {}'.format(self.name))
     for param in hyper_params:
         assert isinstance(param, HyperParameter)
         assert isinstance(param, self.valid_HP_types)
         description = '{}: {}'.format(param.name, param.option_str)
         console.supplement(description, level=2)
         self.hyper_params[param.name] = param
コード例 #6
0
ファイル: fi2010.py プロジェクト: winkywow/tframe
 def _validate_setup2(cls, data_dir, auction, train_set):
     """Compare the train set features with a freshly loaded z-score set.

     The reference is the first 40 columns of the first 'raw_data' array
     of the set loaded with norm_type='zscore' and setup=2. The check
     passes when both have the same length and differ by less than 1e-4
     everywhere.

     :raises AssertionError: if `train_set` is not a SequenceSet or the
                             features do not match the reference
     """
     console.show_status('Validating train set ...', '[Setup2-ZScore]')
     assert isinstance(train_set, SequenceSet)
     # Load the z-score data set used as reference
     slices = (slice(6, 7), slice(7, 9))
     zs_set = cls.load_as_tframe_data(
         data_dir, auction=auction, norm_type='zscore', setup=2,
         file_slices=slices)
     reference = zs_set.data_dict['raw_data'][0][:, :40]
     candidate = np.concatenate(train_set.features, axis=0)
     assert len(reference) == len(candidate)
     max_deviation = np.max(np.abs(reference - candidate))
     assert max_deviation < 1e-4
     console.show_info('Validation completed.')
コード例 #7
0
ファイル: fi2010.py プロジェクト: winkywow/tframe
 def _check_targets(cls, data_dir, auction, data_dict):
     """Verify the targets in `data_dict` against a loaded z-score set.

     For each horizon in (10, 20, 30, 50, 100) the concatenated targets
     of `data_dict` must equal those of the z-score set element-wise.

     :raises AssertionError: if `data_dict` is not a dict, the loaded set
                             is not a SequenceSet, or any horizon differs
     """
     console.show_status('Checking targets list ...')
     assert isinstance(data_dict, dict)
     # Load z-score data
     slices = (slice(8, 9), slice(8, 9))
     zscore_set = cls.load_as_tframe_data(
         data_dir, auction=auction, norm_type='zscore', setup=9,
         file_slices=slices)
     assert isinstance(zscore_set, SequenceSet)
     # Check targets horizon by horizon
     for horizon in (10, 20, 30, 50, 100):
         lob_targets = np.concatenate(data_dict[horizon])
         zs_targets = np.concatenate(zscore_set.data_dict[horizon])
         if np.equal(lob_targets, zs_targets).all():
             continue
         raise AssertionError(
             'Targets not equal when horizon = {}'.format(horizon))
     console.show_info('Targets are all correct.')
コード例 #8
0
ファイル: fi2010.py プロジェクト: winkywow/tframe
 def _init_features_and_targets(cls, lob_set, horizon):
     """Set `lob_set.features` and `lob_set.targets` from its data_dict.

     x = {P_ask[i], V_ask[i], P_bid[i], V_bid[i]}_i=1^10

     Features are the first ``4 * th.max_level`` columns of every array in
     ``lob_set.data_dict['raw_data']``. If 'use_log' is in
     ``th.developer_code``, ``log10(x + 1)`` is applied to the odd columns
     (the volume columns in the layout above). If ``th.volume_only`` is
     set, only the odd columns are kept.

     The log transform is applied to copies, so the arrays stored in
     ``lob_set.data_dict['raw_data']`` are left untouched and repeated
     calls do not compound the transform.

     :param lob_set: data set providing `data_dict`; modified in-place
     :param horizon: key of the targets inside ``lob_set.data_dict``
     :return: the same `lob_set` with features and targets set
     """
     max_level = checker.check_positive_integer(th.max_level)
     assert 0 < max_level <= 10
     # Initialize features
     features = lob_set.data_dict['raw_data']
     # .. max_level
     features = [array[:, :4 * max_level] for array in features]
     # .. check developer code
     if 'use_log' in th.developer_code:
         # Column slicing yields views of the raw arrays (assuming they are
         # numpy ndarrays); copy first so the in-place assignment below
         # does not write back into data_dict['raw_data']
         features = [array.copy() for array in features]
         for x in features:
             x[:, 1::2] = np.log10(x[:, 1::2] + 1.0)
         console.show_info('log10 applied to features', '++')
     # .. volume only
     if th.volume_only:
         features = [array[:, 1::2] for array in features]
     # Set features back
     lob_set.features = features
     # Initialize targets
     lob_set.targets = lob_set.data_dict[horizon]
     return lob_set
コード例 #9
0
    def run(self, times=1, save=False, mark='', rehearsal=False):
        """Launch the module once per distinct hyper-parameter configuration.

        :param times: number of runs; overridden by ``self._sys_runs`` when
                      that is not None
        :param save: if True, 'save_model' is set in the common parameters
                     and a script suffix is assigned to each task
        :param mark: prefix used in the script suffix
        :param rehearsal: if True, tasks are only displayed (not called) and
                          the method returns after the first run
        """
        if self._sys_runs is not None:
            times = checker.check_positive_integer(self._sys_runs)
            console.show_status('Run # set to {}'.format(times))
        # Set the corresponding flags if save
        if save:
            self.common_parameters['save_model'] = True
        self._show_parameters()

        task_count = 0
        for run_index in range(times):
            # Config strings already handled in this run
            seen = []
            for hyper_dict in self._hyper_parameter_dicts():
                task_count += 1
                # self._add_script_suffix, when set, takes priority over save
                if self._add_script_suffix is not None:
                    save = self._add_script_suffix
                if save:
                    suffix = '_{}{}'.format(mark, task_count)
                    self.common_parameters['script_suffix'] = suffix

                params = self._get_all_configs(hyper_dict)
                self._apply_constraints(params)

                # Skip configurations that have already appeared in this run
                arg_list = self._get_config_strings(params)
                signature = ' '.join(arg_list)
                if signature in seen:
                    continue
                seen.append(signature)

                progress = '[Run {}/{}][{}]'.format(
                    run_index + 1, times, len(seen))
                console.show_status('Loading task ...', progress)
                console.show_info('Hyper-parameters:')
                for key, value in hyper_dict.items():
                    console.supplement('{}: {}'.format(key, value))
                if rehearsal:
                    continue
                call([self._python_cmd, self.module_name] + arg_list)
                print()
            # End of the run
            if rehearsal:
                return
コード例 #10
0
ファイル: scroll_base.py プロジェクト: winkywow/tframe
    def set_constraints(self, constraints):
        """Validate the given constraints and store them in self.constraints.

        Does nothing for an empty dict.

        :param constraints: dict mapping a condition (tuple) to a dict of
                            constrained values
        :raises AssertionError: on wrong types, or when a multi-valued
                                constraint contains a value that is not among
                                the choices of the corresponding
                                hyper-parameter
        """
        assert isinstance(constraints, dict)
        if not constraints:
            return
        # Show constraints
        console.show_info('Constraints')
        for cond, cons in constraints.items():
            assert isinstance(cond, tuple) and isinstance(cons, dict)
            # Every value of a multi-value constraint must have been
            # registered, otherwise ambiguity will arise when constraints
            # are applied
            for key, value in cons.items():
                if not isinstance(value, (tuple, set, list)):
                    continue
                for choice in value:
                    if choice in self.hyper_params[key].choices:
                        continue
                    raise AssertionError(
                        '!! Value {} for {} when {} should be registered first.'
                        .format(choice, key, cond))

            # Consider skipping some constraint settings here, i.e., when
            #   the condition will never be satisfied

            console.supplement('{}: {}'.format(cond, cons), level=2)
            self.constraints[cond] = cons
コード例 #11
0
ファイル: seq_set.py プロジェクト: garthtrickett/tframe
 def set_rnn_batch_generator(self, generator):
     """Attach `generator` under the attribute named RNN_BATCH_GENERATOR.

     :param generator: a callable
     :raises AssertionError: if `generator` is not callable
     """
     assert callable(generator)
     attribute_name = self.RNN_BATCH_GENERATOR
     setattr(self, attribute_name, generator)
     # The message reports self.name
     message = "RNN batch generator set to {}".format(self.name)
     console.show_info(message, "++")
コード例 #12
0
ファイル: fi2010.py プロジェクト: winkywow/tframe
 def _check_raw_lob(cls, data_dir, auction, lob_list, raise_err=False):
     """Check a raw LOB list against the z-score normalized data set.

     The raw LOB arrays are normalized with the mean and standard deviation
     of the first array and compared with the first 40 columns of the loaded
     z-score data. If the maximum absolute difference is below 1e-4, the
     list is considered correct. Otherwise the deviating entries are
     rebuilt from the min-max normalized data and the corrected list is
     checked again.

     NOTE(review): the correction path looks unfinished. It ends in an
     unconditional `assert False`, so as written this method can only return
     normally (True) when the list is already correct.

     :param data_dir: forwarded to `cls.load_as_tframe_data`
     :param auction: bool, forwarded to the z-score load
     :param lob_list: list of exactly 2 numpy arrays with 40 columns each
     :param raise_err: if True, raise AssertionError instead of attempting
                       a correction when the check fails
     :return: True if the LOB list matches the z-score data
     :raises AssertionError: on invalid arguments, on a failed check when
                             `raise_err` is True, on inconsistent
                             reconstructed values, and unconditionally at
                             the end of the correction path
     """
     console.show_status('Checking LOB list ...')
     # Sanity check
     assert isinstance(auction, bool) and len(lob_list) == 2
     for lob in lob_list:
         assert isinstance(lob, np.ndarray) and lob.shape[1] == 40
     # Calculate stats for normalization (column-wise, first array only)
     lob_1_9 = lob_list[0]
     mu, sigma = np.mean(lob_1_9, axis=0), np.std(lob_1_9, axis=0)
     x_min, x_max = np.min(lob_1_9, axis=0), np.max(lob_1_9, axis=0)
     x_deno = x_max - x_min
     # Load z-score data
     zscore_set = cls.load_as_tframe_data(data_dir,
                                          auction=auction,
                                          norm_type='zscore',
                                          setup=9,
                                          file_slices=(slice(8, 9),
                                                       slice(8, 9)))
     assert isinstance(zscore_set, SequenceSet)
     zs_all = np.concatenate(
         [array[:, :40] for array in zscore_set.data_dict['raw_data']],
         axis=0)
     # Load min-max data
     # NOTE(review): `auction` is hard-coded to False here, whereas the
     # z-score load above uses the `auction` argument - confirm this is
     # intended
     mm_set = cls.load_as_tframe_data(data_dir,
                                      auction=False,
                                      norm_type='minmax',
                                      setup=9,
                                      file_slices=(slice(8, 9), slice(8,
                                                                      9)))
     mm_all = np.concatenate(
         [array[:, :40] for array in mm_set.data_dict['raw_data']], axis=0)
     # Generate lob -> zscore data for validation
     lob_all = np.concatenate(lob_list, axis=0)
     lob_zs_all = (lob_all - mu) / sigma
     # Check error
     max_err = 1e-4
     delta_all = np.abs(lob_zs_all - zs_all)
     if np.max(delta_all) < max_err:
         console.show_info('LOB list is correct.')
         return True
     if raise_err: raise AssertionError
     # Correct the deviating LOB entries using the min-max data
     console.show_status('Correcting LOB list ...')
     V_errs, P_errs = 0, 0
     bar = ProgressBar(total=len(lob_all))
     # Visit every (row, column) whose z-score deviates by more than max_err
     for i, j in np.argwhere(delta_all > max_err):
         # Even columns are counted as price errors, odd ones as volume errors
         price_err = j % 2 == 0
         V_errs, P_errs = V_errs + 1 - price_err, P_errs + price_err
         # Find correct value: de-normalize both references and require them
         # to agree within 0.1
         val_zs = zs_all[i][j] * sigma[j] + mu[j]
         val_mm = mm_all[i][j] * x_deno[j] + x_min[j]
         zs_mm_err = abs(val_zs - val_mm)
         if zs_mm_err > 0.1:
             raise AssertionError(
                 'In LOB[{}, {}] val_zs = {} while val_mm = {}'.format(
                     i, j, val_zs, val_mm))
         correct_val = val_mm
         # NOTE(review): this tests the cumulative price-error count rather
         # than `price_err`, so rounding is applied only until the first
         # price error has been seen - confirm whether `not price_err` was
         # intended
         if not P_errs:
             correct_val = np.round(val_mm)
             cor_mm_err = abs(correct_val - val_mm)
             if cor_mm_err > 1e-3:
                 # NOTE(review): the message labels `cor_mm_err` (the
                 # rounding error) as cor_val
                 raise AssertionError(
                     'In LOB[{}, {}] cor_val = {} while val_mm = {}'.format(
                         i, j, cor_mm_err, val_mm))
         # Correct value in lob_all
         lob_all[i, j] = correct_val
         bar.show(i)
     # Show status after correction
     console.show_status(
         '{} price errors and {} volume errors have been corrected'.format(
             P_errs, V_errs))
     # Split the corrected array back into blocks with the original lengths
     new_lob_list = []
     for s in [len(array) for array in lob_list]:
         day_block, lob_all = np.split(lob_all, [s])
         new_lob_list.append(day_block)
     # Re-check the corrected list; this time a mismatch raises
     assert cls._check_raw_lob(data_dir, auction, new_lob_list, True)
     # for i in range(10): lob_list[i] = new_lob_list[i] TODO
     # NOTE(review): the corrected list is never written back or returned;
     # execution always stops here
     assert False
コード例 #13
0
ファイル: script_helper.py プロジェクト: rscv5/tframe
 def _show_config(name, od):
     assert isinstance(od, OrderedDict)
     if len(od) == 0: return
     console.show_info(name)
     for k, v in od.items():
         console.supplement('{}: {}'.format(k, v), level=2)