Example #1
0
    def evaluate_model(self, data, batch_size=None, dynamic=False, **kwargs):
        """Evaluate `self.eval_metric` on `data` and print the result.

        Note `evaluate` here has a narrower meaning than in `self.evaluate`:
        only the eval metric is computed, and its value is shown on the
        terminal.
        """
        # The metric must have been defined before evaluation
        if not self.eval_metric.activated:
            raise AssertionError('!! Metric not defined')

        # Delegate to the dynamic evaluator when requested
        if dynamic:
            from tframe.trainers.eval_tools.dynamic_eval import DynamicEvaluator as de
            de.dynamic_evaluate(self, data, kwargs.get('val_set', None),
                                kwargs.get('delay', None))
            return

        # When hub.val_progress_bar is on, model.evaluate shows this message
        if not hub.val_progress_bar:
            console.show_status('Evaluating on {} ...'.format(data.name))

        # Temporarily reuse the val_progress_bar option as the verbose flag
        scalar_dict = self.validate_model(
            data, batch_size, allow_sum=False, verbose=hub.val_progress_bar)
        result = scalar_dict[self.eval_metric]
        console.supplement('{} = {}'.format(
            self.eval_metric.symbol,
            hub.decimal_str(result, hub.val_decimals)))
        return result
Example #2
0
    def record_stats_on_dataset(self,
                                data_set,
                                slot_scalar_dict,
                                take_down_on_slot=False,
                                rnd=None):
        """Record metric statistics for a data set.

    Stats are kept in instances of class Statistic so that metrics on
    different data sets are stored separately.

    :param data_set: a tframe DataSet
    :param slot_scalar_dict: a dictionary returned by model.validate_model
    :param take_down_on_slot: whether to record stats on metric_slots,
                              usually set to True if data_set is val_set
    :param rnd: if take_down_on_slot, rnd must be provided
    :return: True if a new record appeared on the early-stop slot
    """
        # Sanity check
        assert isinstance(slot_scalar_dict, dict)

        # Lazily create an OrderedDict for this data_set
        if data_set not in self.stats_dict:
            self.stats_dict[data_set] = OrderedDict()
        od = self.stats_dict[data_set]
        assert isinstance(od, OrderedDict)

        new_record_on_early_stop = False
        for slot, scalar in slot_scalar_dict.items():
            assert isinstance(slot, MetricSlot)
            # Lazily create a Statistic for this slot on data_set
            if slot not in od: od[slot] = Statistic(max_length=2)
            stat = od[slot]
            assert isinstance(stat, Statistic)
            # Record the scalar
            stat.record(scalar)
            # Nothing else to do unless stats go on the slot itself
            if not take_down_on_slot: continue
            assert rnd is not None
            new_record = slot.take_down(scalar, rnd, self.model.counter,
                                        hub.record_gap)
            # Take note for later print
            note_key = (data_set, slot)
            if new_record:
                self.note[note_key] = '<New Record>'
                if slot is self.early_stop_slot:
                    new_record_on_early_stop = True
                    if self.resurrected:
                        self._record_after_resurrection(scalar)
            else:
                idle = self.idle_counter(slot, rnd)
                if hub.early_stop and slot is self.early_stop_slot:
                    idle_info = 'Patience {}/{}'.format(idle, self.th.patience)
                else:
                    idle_info = 'Idle: {}'.format(idle)
                self.note[note_key] = '(Best: {}, {})'.format(
                    hub.decimal_str(slot.record, hub.val_decimals), idle_info)

        return new_record_on_early_stop
Example #3
0
 def _validate(self, data_set, batch_size=1, num_steps=1000):
     """Validate the model on `data_set` and print each metric value."""
     results = self.model.validate_model(
         data_set, batch_size=batch_size, verbose=True, num_steps=num_steps,
         seq_detail=th.val_info_splits > 0)
     # Show every returned metric on the terminal
     for name, value in results.items():
         console.supplement(
             '{} = {}'.format(name, th.decimal_str(value, th.val_decimals)))
Example #4
0
 def _dynamic_eval(self,
                   data_set,
                   lr,
                   lambd,
                   prompt='[Dynamic Evaluation]'):
     """Run one dynamic-evaluation pass with the given hyper-parameters.

     :param data_set: a tframe DataSet to evaluate on
     :param lr: learning rate passed to the dynamic optimizer
     :param lambd: decay hyper-parameter passed to the dynamic optimizer
     :param prompt: status-line prefix shown on the terminal
     :return: the summarized metric value
     """
     assert isinstance(data_set, DataSet)
     console.show_status('', prompt)
     # Restore parameters, then install the given hyper-parameters
     self.optimizer.reset_parameters()
     self.optimizer.set_hyper_parameters(lr, lambd)
     # Do dynamic evaluation; the first fetch holds the metric quantities
     outputs = self.model.evaluate(self._dynamic_fetches,
                                   data_set,
                                   batch_size=1,
                                   verbose=True,
                                   num_steps=th.de_num_steps)
     assert isinstance(self._quantity, Quantity)
     metric = self._quantity.apply_np_summ_method(outputs[0])
     console.supplement('Dynamic {} = {}'.format(
         self._quantity.name, th.decimal_str(metric, th.val_decimals)))
     return metric
Example #5
0
    def _calculate_gradient_stats(self):
        """Estimate sqrt(mean(g^2)) per variable over the training set and
        assign it into `self._sqrt_MS_g`, optionally saving a checkpoint.

        Runs `self._model.evaluate` with one squared-gradient fetch per
        variable (plus the metric quantities if a metric is configured),
        averages each variable's squared gradients over the evaluation
        outputs, and assigns the square root via `tf.assign` in one
        session run.
        """
        # Sanity check
        checker.check_type(th.train_set, DataSet)
        checker.check_positive_integer(th.de_batch_size)
        checker.check_positive_integer(th.de_num_steps)
        self.show_status('Calculating gradient stats on training set ...')

        # One fetch per variable: the element-wise square of its gradient.
        # NOTE(review): `fetches` aliases `grad_square`, so the append below
        # also extends `grad_square`; harmless since grad_square is not
        # used afterwards.
        grad_square = [tf.square(self._grads[var]) for var in self._var_list]
        fetches = grad_square
        if self._metric_quantity is not None:
            assert isinstance(self._metric_quantity, Quantity)
            fetches.append(self._metric_quantity.quantities)

        # Check train_set
        train_set = th.train_set
        if not isinstance(train_set, DataSet):
            raise TypeError(
                '!! th.train_set must be an instance of DataSet but has'
                ' type `{}`'.format(type(train_set)))
        # Truncate train set if necessary
        if th.de_max_batches > 0:
            # For a SequenceSet each element is one batch; otherwise the
            # element count per batch is batch_size * num_steps
            if isinstance(train_set, SequenceSet): size = th.de_max_batches
            else: size = th.de_batch_size * th.de_num_steps * th.de_max_batches
            train_set = train_set[:size]
            train_set.name = 'train_set[:{}]'.format(size)
            # Show info
            # self.show_status('train_set truncated to de_max_batches({})'.format(size))
        # num_steps = th.eval_num_steps if th.eval_num_steps else th.de_num_steps
        num_steps = th.de_num_steps
        # `outputs` is expected to be one list per fetch, each holding
        # per-batch numpy results -- TODO confirm against model.evaluate
        outputs = self._model.evaluate(fetches,
                                       train_set,
                                       batch_size=th.de_batch_size,
                                       num_steps=num_steps,
                                       verbose=True)

        # Show metric on training set if provided.  The metric fetch was
        # appended last, so pop it before pairing outputs with variables.
        if self._metric_quantity is not None:
            metric_quantities = outputs.pop(-1)
            metric_val = self._metric_quantity.apply_np_summ_method(
                metric_quantities)
            console.supplement('{} on training set = {}'.format(
                self._metric_quantity.name,
                th.decimal_str(metric_val, th.val_decimals)))

        # Assign mean square grads: average squared gradients over batches,
        # take the element-wise square root, and push into _sqrt_MS_g
        assign_ops = []
        for var, output_list in zip(self._var_list, outputs):
            assert isinstance(output_list, list)
            mean_square = np.mean(output_list, axis=0)
            sqrt_mean_square = np.sqrt(mean_square)
            assign_ops.append(tf.assign(self._sqrt_MS_g[var],
                                        sqrt_mean_square))
        self._model.session.run(assign_ops)

        # After gradient stats have been calculated, save them into disk
        # .. if necessary
        if th.de_save_train_stats:
            th.train_stats_exists = True
            # When th.train_stats_exists is True,
            # .. saver will initiated with _sqrt_MS_g
            self._model.agent.reset_saver()
            self._model.agent.save_model(suffix='DeStat')
            self.show_status('sqrt_MS_g saved to checkpoint')