Example #1
0
    def record_stats_on_dataset(self,
                                data_set,
                                slot_scalar_dict,
                                take_down_on_slot=False,
                                rnd=None):
        """Record the latest metric scalars computed on a data set.

        Every (data_set, slot) pair owns a Statistic(max_length=2) stored in
        `self.stats_dict[data_set]`. When `take_down_on_slot` is True each
        scalar is also offered to `slot.take_down`, and a note for later
        printing is stored in `self.note[(data_set, slot)]`.

        :param data_set: a tframe DataSet
        :param slot_scalar_dict: a dictionary returned by model.validate_model,
                                 mapping MetricSlot -> scalar
        :param take_down_on_slot: whether to record stats on metric_slots,
                                  usually set to True if data_set is val_set
        :param rnd: if take_down_on_slot, rnd must be provided
        :return: True if a new record appeared on self.early_stop_slot,
                 otherwise False
        """
        # Sanity check
        # assert isinstance(data_set, DataSet)
        assert isinstance(slot_scalar_dict, dict)

        # One OrderedDict of Statistics per data set, created lazily
        if data_set not in self.stats_dict.keys():
            self.stats_dict[data_set] = OrderedDict()
        stats_of_set = self.stats_dict[data_set]
        early_stop_record = False
        assert isinstance(stats_of_set, OrderedDict)

        for metric_slot, value in slot_scalar_dict.items():
            assert isinstance(metric_slot, MetricSlot)
            # Lazily create a Statistic for this slot on this data set
            if metric_slot not in stats_of_set.keys():
                stats_of_set[metric_slot] = Statistic(max_length=2)
            statistic = stats_of_set[metric_slot]
            assert isinstance(statistic, Statistic)
            statistic.record(value)

            if not take_down_on_slot:
                continue

            assert rnd is not None
            is_new_record = metric_slot.take_down(
                value, rnd, self.model.counter, hub.record_gap)
            # Note kept for later printing
            key = (data_set, metric_slot)
            if not is_new_record:
                idle = self.idle_counter(metric_slot, rnd)
                if hub.early_stop and metric_slot is self.early_stop_slot:
                    idle_info = 'Patience {}/{}'.format(idle, self.th.patience)
                else:
                    idle_info = 'Idle: {}'.format(idle)
                self.note[key] = '(Best: {}, {})'.format(
                    hub.decimal_str(metric_slot.record, hub.val_decimals),
                    idle_info)
                continue

            self.note[key] = '<New Record>'
            if metric_slot is self.early_stop_slot:
                early_stop_record = True
                if self.resurrected:
                    self._record_after_resurrection(value)

        return early_stop_record
Example #2
0
    def __init__(
        self,
        model,
        training_set=None,
        validation_set=None,
        snapshot=None,
        probe=None,
        evaluate=None,
        terminator=None,
        test_set=None,
    ):
        """Bind a tframe model to a new trainer and set up training state.

        :param model: the tfr.models.Model to train; its metrics manager is
                      linked back to this trainer
        :param training_set: optional training data, passed to set_data
        :param validation_set: optional validation data, passed to set_data
        :param snapshot: optional callable stored as the snapshot function
        :param probe: optional callable stored as the probe
        :param evaluate: optional callable stored as the evaluator
        :param terminator: optional callable stored as the terminator
        :param test_set: optional test data, passed to set_data
        :raises TypeError: if model is not a tframe Model, or terminator is
                           given but not callable
        """
        # Set model for trainer
        if not isinstance(model, tfr.models.Model):
            raise TypeError('!! model must be an instance of tframe Model')
        self.model = model
        self.model.metrics_manager.trainer = self

        # Data set attributes
        self._training_set = None
        self._validation_set = None
        self._test_set = None
        self.set_data(training_set, validation_set, test_set)

        # Set callable attributes
        self._snapshot_function = checker.check_callable(snapshot)
        self._probe = checker.check_callable(probe)
        self._evaluate = checker.check_callable(evaluate)

        # Initiate trainer hub
        self.th = TrainerHub(self)

        # Private Attributes
        self._record_count = 0
        self._warm_up = True
        self.batch_loss_stat = Statistic(max_length=self.th.hist_buffer_len)

        # NOTE(review): this instance attribute shadows any class-level
        # HubClass attribute -- confirm this is intended for subclasses.
        self.HubClass = TrainerHub
        # Validate explicitly: an `assert` would be stripped under `python -O`
        if terminator is not None and not callable(terminator):
            raise TypeError('!! terminator must be callable')
        self._terminator = terminator

        # Important, since th.lives initialized by shell command will not change
        self._lives = self.th.lives

        # TODO
        # temporary solution to give agent the access to trainer
        context.trainer = self
Example #3
0
class Trainer(object):
    """Base class of trainer for training tframe models.

     Model save mechanism when save_mode is
       (1) SaveMode.NAIVE:
           Model will be saved only at the end of each round naively
       (2) SaveMode.ON_RECORD:
           Model will be saved only when a new metric record appears
           after model finishes its warm-up rounds
   """
    HubClass = None

    def __init__(
        self,
        model,
        training_set=None,
        validation_set=None,
        snapshot=None,
        probe=None,
        evaluate=None,
        terminator=None,
        test_set=None,
    ):
        """Bind a tframe model to a new trainer and set up training state.

        :param model: the tfr.models.Model to train; its metrics manager is
                      linked back to this trainer
        :param training_set: optional training data, passed to set_data
        :param validation_set: optional validation data, passed to set_data
        :param snapshot: optional callable stored as the snapshot function
        :param probe: optional callable, called as probe(trainer)
        :param evaluate: optional callable, called as evaluate(trainer) at
                         the end of training
        :param terminator: optional callable; on a new validation record it
                           is called with the early-stop criterion and a
                           truthy result forces termination
        :param test_set: optional test data, passed to set_data
        :raises TypeError: if model is not a tframe Model, or terminator is
                           given but not callable
        """
        # Set model for trainer
        if not isinstance(model, tfr.models.Model):
            raise TypeError('!! model must be an instance of tframe Model')
        self.model = model
        self.model.metrics_manager.trainer = self

        # Data set attributes
        self._training_set = None
        self._validation_set = None
        self._test_set = None
        self.set_data(training_set, validation_set, test_set)

        # Set callable attributes
        self._snapshot_function = checker.check_callable(snapshot)
        self._probe = checker.check_callable(probe)
        self._evaluate = checker.check_callable(evaluate)

        # Initiate trainer hub
        self.th = TrainerHub(self)

        # Private Attributes
        self._record_count = 0
        self._warm_up = True
        self.batch_loss_stat = Statistic(max_length=self.th.hist_buffer_len)

        # NOTE(review): this instance attribute shadows the class-level
        # HubClass (and any subclass override) -- confirm this is intended.
        self.HubClass = TrainerHub
        # Validate explicitly: an `assert` would be stripped under `python -O`
        if terminator is not None and not callable(terminator):
            raise TypeError('!! terminator must be callable')
        self._terminator = terminator

        # Important, since th.lives initialized by shell command will not change
        self._lives = self.th.lives

        # TODO
        # temporary solution to give agent the access to trainer
        context.trainer = self

    # region : Properties

    @property
    def key_metric(self):
        """The early-stop slot of the metrics manager."""
        manager = self.metrics_manager
        return manager.early_stop_slot

    @property
    def training_set(self):
        """Training data: None if unset, otherwise asserted to be TFRData."""
        data = self._training_set
        assert data is None or isinstance(data, TFRData)
        return data

    @property
    def validation_set(self):
        """Validation data: None if unset, otherwise asserted to be TFRData."""
        data = self._validation_set
        assert data is None or isinstance(data, TFRData)
        return data

    @property
    def test_set(self):
        """Test data: None if unset, otherwise asserted to be TFRData."""
        data = self._test_set
        assert data is None or isinstance(data, TFRData)
        return data

    @property
    def is_online(self):
        """True when the training set is a PerpetualMachine."""
        training_data = self.training_set
        return isinstance(training_data, PerpetualMachine)

    @property
    def session(self):
        """The model's session, asserted to be a tf.Session."""
        sess = self.model.session
        assert isinstance(sess, tf.Session)
        return sess

    @property
    def counter(self):
        """Iteration counter, delegated to the model."""
        model = self.model
        return model.counter

    @counter.setter
    def counter(self, value):
        model = self.model
        model.counter = value

    @property
    def metrics_manager(self):
        """The model's MetricsManager (type-checked on every access)."""
        manager = self.model.metrics_manager
        assert isinstance(manager, MetricsManager)
        return manager

    @property
    def total_rounds(self):  # TODO: CC
        """Rounds elapsed as counter / th.round_length, or None.

        None is returned when th.round_length is None.
        """
        # TODO: Batch size must be kept the same among different trials
        round_length = self.th.round_length
        return None if round_length is None else self.counter / round_length

    @property
    def graph(self):
        """The model's graph."""
        model = self.model
        return model.graph

    @property
    def _save_model_when_record_appears(self):
        """Whether a new metric record should trigger saving the model.

        Requires th.save_model with SaveMode.ON_RECORD and a finished warm-up.
        If th.at_most_save_once_per_round is set, _record_count must also not
        exceed 1.
        """
        hub = self.th
        on_record = hub.save_model and hub.save_mode is SaveMode.ON_RECORD
        return (on_record and not self._warm_up
                and not (hub.at_most_save_once_per_round
                         and self._record_count > 1))

    @property
    def _save_model_at_round_end(self):
        """Whether the model is saved at the end of every round (NAIVE mode)."""
        hub = self.th
        return hub.save_model and hub.save_mode is SaveMode.NAIVE

    @property
    def _save_model_at_training_end(self):
        """Whether the model is saved once after training finishes."""
        hub = self.th
        return hub.save_model and hub.save_model_at_the_end

    # endregion : Properties

    # region : Public Methods

    def set_data(self, training_set=None, validation_set=None, test_set=None):
        """Attach data sets to the trainer; None arguments are ignored.

        Each given set is validated by _check_data and then stored. This is
        done in the order training, validation, test.
        """
        candidates = (
            ('_training_set', training_set, 'training set'),
            ('_validation_set', validation_set, 'validation set'),
            ('_test_set', test_set, 'test set'),
        )
        for attr_name, data, label in candidates:
            if data is None:
                continue
            self._check_data(data, label)
            setattr(self, attr_name, data)

    def recover_progress(self, start_time=None):
        """Redraw the progress bar for the current round.

        Does nothing unless th.progress_bar is on and th.round_length is set.
        """
        hub = self.th
        if not hub.progress_bar or hub.round_length is None:
            return
        assert isinstance(self._training_set, TFRData)
        progress = hub.round_progress
        assert progress is not None
        console.print_progress(progress=progress, start_time=start_time)

    # endregion : Public Methods

    # region : Train

    @with_graph
    def train(self, hub=None, **kwargs):
        """Run the whole training procedure.

        The procedure is:
        - set up the trainer hub and run model.pretrain;
        - check data and settings, then launch the model if needed;
        - show configurations;
        - run the outer loop inside the model session;
        - finish training and handle notes;
        - prune and save if th.prune_on.

        :param hub: optional hub instance to adopt (must be a HubClass);
                    when None, the existing hub is set up with kwargs
        :param kwargs: forwarded to _init_trainer_hub and model.pretrain
        """
        # -- Preparation
        self._init_trainer_hub(hub, **kwargs)
        self.model.pretrain(**kwargs)
        self._check_data()
        self._sanity_check()
        self.th.sanity_check()
        self._check_model()
        self._show_configurations()
        self._take_notes_before_loops()

        # -- Training loops, run with the model session as default
        with self.session.as_default():
            rounds = self._outer_loop()

        # -- After training
        self._end_training(rounds)
        self._handle_notes()

        # Prune and save if necessary
        if self.th.prune_on:
            context.pruner.prune_and_save_lottery18()

    # region : Before training

    def _init_trainer_hub(self, hub, **kwargs):
        """Adopt a given hub, or configure the existing self.th with kwargs.

        Afterwards, the progress bar is only kept on if the round length is
        active, and warm-up is switched off if the hub disables it.

        :raises TypeError: if hub is given but is not a HubClass instance
        """
        if hub is None:
            self.th.set_up(**kwargs)
        else:
            if not isinstance(hub, self.HubClass):
                raise TypeError('!! config must be an instance of {}'.format(
                    self.HubClass))
            self.th = hub
            hub.trainer = self

        # Progress bar requires an active round length
        if self.th.progress_bar:
            self.th.progress_bar = self.th.round_len_is_active

        # Other setting
        if not self.th.warm_up:
            self._warm_up = False

    def _sanity_check(self):
        """Hook for subclass-specific checks; should be overridden."""
        return None

    def _show_configurations(self):
        """Print the hub's configuration strings and record them in notes."""
        agent = self.model.agent
        console.show_status('Configurations:')
        agent.take_notes('Configurations:', date_time=False)
        for line in self.th.config_strings:
            console.supplement(line)
            agent.take_notes('.. {}'.format(line), date_time=False)

    def _take_notes_before_loops(self):
        """Placeholder for taking notes before the training loops.

        It currently only reads th.export_note; no notes are taken either way.
        """
        if self.th.export_note:
            return

    def _check_model(self):
        """Launch the model if needed, and set model.rounds to 0 if unset."""
        model = self.model
        if not model.launched:
            model.launch_model(self.th.overwrite)
        # Offline training counts rounds; initialise the count if missing
        needs_rounds = not self.is_online and model.rounds is None
        if needs_rounds:
            model.rounds = 0

    # endregion : Before training

    # region : During training

    def _outer_loop(self):
        """Run the training rounds, then do post-training bookkeeping.

        Each of the th.total_outer_loops rounds does the following:
        - run the inner loop and increment model.rounds;
        - end the round on the model (offline with validation on);
        - optionally save the model;
        - decide whether to stop. While lives remain, the model is
          resurrected instead of stopping.

        After the loop, total rounds/iterations and the weight fraction are
        put down. The selected data sets are then evaluated (after loading
        the saved model if th.save_model), and the model may be saved.

        :return: the number of rounds that were started
        """
        hub = self.th
        rnd = 0
        for _ in range(hub.total_outer_loops):
            rnd += 1
            if self.is_online: console.section('Iterations Begin')
            else: console.section('{} {}'.format(hub.round_name, rnd))
            hub.tic()

            # Do inner loop
            self._inner_loop(rnd)
            # End of round
            if hub.progress_bar:
                console.show_status(
                    'End of {}. Elapsed time is {:.1f} secs'.format(
                        hub.round_name, hub.toc()))
            # Inc rounds for models training in epochs
            if self.model.rounds is not None:
                self.model.rounds += 1.0
            # Maybe give a report on metric
            if not self.is_online and hub.validation_on:
                self.model.end_round(rnd)
                # NOTE(review): unlike the online check in _inner_loop, this
                # does not consult th.early_stop before comparing idle rounds
                # with patience -- confirm this is intended.
                if self.key_metric.get_idle_rounds(rnd) > self.th.patience:
                    self.th.raise_stop_flag()

            # Maybe save model (model.rounds var has been increased)
            if self._save_model_at_round_end: self._save_model()

            break_flag = False
            # Early stop via stop flag TODO: needed to be unified
            if hub.stop and self.model.bust(rnd): break_flag = True
            # Force terminate
            if hub.force_terminate: break_flag = True
            # Resurrect if possible: spend a life instead of breaking. On the
            # first resurrection, also store the current early-stop criterion
            # in metrics_manager.rar0.
            if break_flag and self._lives > 0:
                self.resurrect(rnd)
                if not self.metrics_manager.resurrected:
                    self.metrics_manager.resurrected = True
                    self.metrics_manager.rar0 = self.metrics_manager.early_stop_criterion
                hub.force_terminate = False
                break_flag = False
            # Break if needed to
            if break_flag: break

        # Out of loop: put down how long training ran
        if hub.gather_note:
            if self.is_online:
                self.model.agent.put_down_criterion('Total Iterations',
                                                    self.counter)
            else:
                self.model.agent.put_down_criterion('Total Rounds', rnd)

        # Put down final weight fraction if etch is on
        if self.th.etch_on:
            frac = context.pruner.weights_fraction
            self.model.agent.take_notes(
                'Final weight fraction: {:.2f}%'.format(frac))
            self.model.agent.put_down_criterion('Weight Fraction', frac)

        # Evaluate the best model if necessary
        ds_dict = OrderedDict()
        if hub.evaluate_train_set: ds_dict['Train'] = self.training_set
        if hub.evaluate_val_set: ds_dict['Val'] = self.validation_set
        if hub.evaluate_test_set: ds_dict['Test'] = self.test_set
        if len(ds_dict) > 0:
            # Load the saved (best) model via the agent before evaluating
            if hub.save_model:
                flag, _, _ = self.model.agent.load()
                assert flag
            # Evaluate the specified data sets
            for name, data_set in ds_dict.items():
                if not isinstance(data_set, TFRData):
                    raise TypeError('!! {} set is not a TFRData'.format(name))
                # TODO
                value = self.model.evaluate_model(
                    data_set, batch_size=hub.eval_batch_size)
                title = '{} {}'.format(name,
                                       self.metrics_manager.eval_slot.name)
                self.model.agent.put_down_criterion(title, value)
                self.model.agent.take_notes('{}: {}'.format(
                    title, hub.decimal_str(value, hub.val_decimals)))

        # Save model here if necessary. Saving at the end is asserted to be
        # exclusive with the post-training evaluation above.
        if self._save_model_at_training_end:
            assert len(ds_dict) == 0
            self._save_model()

        return rnd

    def _inner_loop(self, rnd):
        """Iterate over the training batches of one round.

        Each batch goes through these steps:
        - check the batch and advance th.cursor and the counter;
        - update the model and print progress;
        - validate, saving the model on a new record when allowed;
        - etch, probe and take export notes.

        The loop stops early once th.force_terminate is set. Online runs may
        set it via th.max_iterations or early-stop patience.

        :param rnd: current round number
        """
        # NOTE(review): _record_count is reset here but not incremented in
        # this class as visible; presumably it is updated elsewhere -- verify.
        self._record_count = 0
        # Begin iteration
        self.th.cursor = 0
        for i, batch in enumerate(self._gen_batches()):
            # Sanity check (batch must be a DataSet)
            self._check_data_batch(batch)
            # Increase iteration counter
            self.th.cursor += 1
            self.counter += 1
            # Update model
            loss_dict = self._update_model(batch)
            # Print progress
            self._print_progress(rnd, loss_dict)
            # Validation; save the model when a new record appears and the
            # save policy allows it
            if self._validate_model(
                    rnd) and self._save_model_when_record_appears:
                if not self.is_online:
                    assert np.isscalar(self.th.round_progress)
                self._save_model(inter_cut=True,
                                 progress=self.th.round_progress)
            # Etch
            self._etch()
            # Probe
            self._run_probe()
            # Take notes
            self._take_notes_for_export()

            # Check early stop condition (online training only)
            if self.is_online:
                if self.th.max_iterations is not None:
                    if i + 1 >= self.th.max_iterations:
                        self.th.force_terminate = True
                if self.th.early_stop:
                    if self.key_metric.get_idle_counts(
                            self.counter) > self.th.patience:
                        self.th.force_terminate = True
            # After probing, training process may be terminated
            if self.th.force_terminate:
                # If model will be resurrected later, dynamic_round_len of train_set
                # should be set to None. Otherwise error may occur TODO
                if hasattr(self.training_set, '_clear_dynamic_round_len'):
                    # Perpetual Machine does not have this method
                    self.training_set._clear_dynamic_round_len()
                break
        # Check warm up logic: warm-up ends once a round finishes with fewer
        # records than th.warm_up_thres
        if self._warm_up and self._record_count < self.th.warm_up_thres:
            self._warm_up = False

    def resurrect(self, rnd):
        """Spend one life and reload the model via model.agent.load().

        The key metric's record counter and round are first set to the
        current position. When th.lr_decay < 1, th.clip_lr_multiplier is
        multiplied by lr_decay, the optimizer is optionally reset, and the
        train step is rebuilt.

        :param rnd: current round number
        """
        assert self._lives > 0
        self._lives -= 1
        console.show_status('Lives decreased to {}'.format(self._lives),
                            '[Resurrect]')
        console.show_status('Resurrecting ...')

        # [Compromise] set record counter or round
        metric = self.key_metric
        metric.set_record_counter(self.counter)
        metric.set_record_round(rnd)

        loaded, _, _ = self.model.agent.load()
        assert loaded

        hub = self.th
        if not hub.lr_decay < 1.0:
            return
        assert hub.lr_decay > 0
        hub.clip_lr_multiplier *= hub.lr_decay
        if hub.reset_optimizer_after_resurrection:
            self.model.reset_optimizer()
        self.model.set_train_step()
        console.show_status('Learning rate decayed to {:.6f}'.format(
            hub.learning_rate * hub.clip_lr_multiplier))

    # endregion : During training

    # region : After training

    def _end_training(self, rounds):
        """Wrap up training.

        This flushes summaries, notes the training length, runs the
        evaluate callback and shows RAR info.

        :param rounds: number of rounds returned by the outer loop
        """
        hub = self.th
        agent = self.model.agent
        if hub.progress_bar:
            console.clear_line()

        # If this is a hp-tuning task, write record summary
        if hub.hp_tuning:
            assert not hub.summary
            self.key_metric.write_record_summary()
        if hub.summary or hub.hp_tuning:
            agent.summary_writer.flush()

        # Note how long training lasted
        if self.is_online:
            agent.take_notes(
                'End training after {} iterations'.format(self.counter))
        else:
            total = self.total_rounds
            total_round = '' if total is None else ' ({:.1f} total)'.format(
                total)
            agent.take_notes(
                'End training after {} rounds{}'.format(rounds, total_round))

        # Evaluate, after loading the saved model if models are saved
        if self._evaluate is not None:
            if hub.save_model:
                loaded, _, _ = agent.load()
                assert loaded
            self._evaluate(self)

        # Show RAS if necessary
        if hub.lives > 0:
            ras_info = self.metrics_manager.RAR_string
            console.show_status(ras_info)
            agent.take_notes(ras_info)

    def _handle_notes(self):
        """Complete the notes, then show, export and gather them as set up."""
        hub = self.th
        agent = self.model.agent
        # Metric info only exists when validation is on
        if hub.validation_on:
            self.model.take_down_metric(self.is_online)
        agent.put_down_configs(hub)
        agent.show_notes()
        if hub.export_note:
            agent.export_notes()
        if hub.gather_note:
            agent.gather_notes()

    # endregion : After training

    # endregion : Train

    # region : Private Methods

    def _update_model(self, data_batch):
        """Run one update step and post-process the fetched values.

        The value of the single slot named 'Loss' is recorded in
        batch_loss_stat. Gradient and general-tensor fetches are popped and
        forwarded to context.monitor when enabled. If th.terminate_on_nan is
        set, a NaN among the remaining values forces termination.

        :param data_batch: a DataSet batch
        :return: the dict returned by model.update_model, minus popped entries
        """
        loss_dict = self.model.update_model(data_batch=data_batch)

        # Exactly one slot named 'Loss' is expected
        loss_slots = [slot for slot in loss_dict.keys() if slot.name == 'Loss']
        assert len(loss_slots) == 1
        self.batch_loss_stat.record(loss_dict[loss_slots[0]])

        # <monitor_grad_step_03: fetch and record>
        if self.th.monitor_weight_grads:
            grads = loss_dict.pop(self.model.grads_slot)
            context.monitor.record_grads(grads)

        # Other tensors fetched together with the update
        if self.model.general_tensor_slot.activated:
            tensors = loss_dict.pop(self.model.general_tensor_slot)
            context.monitor.record_tensors(tensors)

        if self.th.terminate_on_nan and any(
                np.isnan(value) for value in loss_dict.values()):
            msg = 'Forced termination triggered due to NAN in loss_dict'
            console.show_status(msg)
            self.model.agent.take_notes(msg)
            self.th.force_terminate = True

        return loss_dict

    def _check_data(self, data_set=None, name='dataset'):
        """Validate a data set; checks the training set when none is given.

        :raises ValueError: if the data set is None
        :raises TypeError: if the data set is not a TFRData instance
        """
        if data_set is None:
            data_set, name = self._training_set, 'training set'
        if data_set is None:
            raise ValueError('!! {} not found'.format(name))
        if isinstance(data_set, TFRData):
            return
        raise TypeError('!! {} must be an instance of TFRData'.format(name))

    @staticmethod
    def _check_callable(f, name=None):
        """Return f after making sure it is None or callable.

        :param f: object to validate; None is allowed
        :param name: label used in the error message (optional)
        :return: f, unchanged
        :raises TypeError: if f is neither None nor callable
        """
        # Validate regardless of whether a name was given; the name only
        # affects the error message. Returning None early here would silently
        # discard f.
        if f is not None and not callable(f):
            label = 'object' if name is None else name
            raise TypeError('!! {} must be callable'.format(label))
        return f

    def _gen_batches(self):
        """Return the iterable of training batches for one round.

        This method will be called only in the inner loop of train process.
        Batches come from model.get_data_batches using th.batch_size,
        th.num_steps and th.shuffle, with is_training=True.

        TODO: for a SequenceSet training set, a batch made of sequences with
        different lengths cannot be used for training yet, because the
        padded 0s may produce inappropriate gradients. No check for this is
        enforced here.
        """
        hub = self.th
        return self.model.get_data_batches(self.training_set,
                                           hub.batch_size,
                                           hub.num_steps,
                                           hub.shuffle,
                                           is_training=True)

    @staticmethod
    def _check_data_batch(batch):
        """Assert that a training batch is a DataSet.

        Equal-length checks on RNN sequence batches are not enforced. The
        gather_indices mechanism makes that constraint unnecessary.
        """
        assert isinstance(batch, DataSet)

    def _advanced_strategy(self, rnd):
        """Hook for subclasses (should be overridden); does nothing here.

        :param rnd: round index (unused in the base implementation)
        """
        return None

    def _inter_cut(self, content, prompt='>>', start_time=None):
        """Show a status line, then redraw the progress bar below it."""
        console.show_status(content, symbol=prompt)
        self.recover_progress(start_time)

    @staticmethod
    def _dict_to_string(dict_):
        """Format a dict as 'k1 = v1, k2 = v2' with values to 3 decimals."""
        assert isinstance(dict_, dict)
        return ', '.join(
            '{} = {:.3f}'.format(key, val) for key, val in dict_.items())

    def _print_progress(self, rnd, loss_dict):
        """Print training losses every th.print_cycle iterations.

        Nothing is printed when loss_dict is None or print_cycle is 0.
        """
        if loss_dict is None or self.th.print_cycle == 0:
            return
        if np.mod(self.counter - 1, self.th.print_cycle) != 0:
            return

        loss_string = self._dict_to_string(loss_dict)
        total = self.total_rounds
        total_rounds = (' - ' if total is None else
                        ' ({:.1f} Total) '.format(total))
        if self.is_online:
            content = 'Iteration {} - {}'.format(self.counter, loss_string)
        else:
            content = '{} {}{}{}'.format(self.th.round_name, rnd, total_rounds,
                                         loss_string)
        self._inter_cut(content, prompt='[Train]',
                        start_time=self.th.start_time)

    def _get_tensors_to_export(self):
        """Collect tensors to export; for now only RNN dynamics are tracked.

        An empty OrderedDict is returned when validation is off. Otherwise,
        the exporter matching the model input type builds the dict:
        Recurrent for RNN_BATCH, Feedforward for anything else. Exported
        variables and monitor stats are then added by
        _get_variables_to_export.
        """
        from tframe.models.recurrent import Recurrent
        from tframe.models.feedforward import Feedforward

        # This method is based on validation set
        if not self.th.validation_on: return OrderedDict()

        if self.model.input_type is InputTypes.RNN_BATCH:
            tensor_dict = Recurrent.get_tensor_to_export(self)
        else:
            tensor_dict = Feedforward.get_tensor_to_export(self)

        # Add variables to export. Keep the returned dict: if the exporter
        # yields None, _get_variables_to_export builds a fresh OrderedDict.
        # Ignoring the return value would drop those variables.
        tensor_dict = self._get_variables_to_export(tensor_dict)

        return tensor_dict

    def _get_variables_to_export(self, tensor_dict):
        """Add exported variables and monitor stats to tensor_dict.

        If tensor_dict already has entries, each entry is treated as a
        per-exemplar dict and every key/value is written into each of them.
        Otherwise, the keys go into tensor_dict itself.

        :param tensor_dict: dict to fill, or None for a fresh OrderedDict
        :return: the filled dict
        """
        if tensor_dict is None:
            tensor_dict = OrderedDict()
        assert isinstance(tensor_dict, dict)

        # Compromise to avoid widget conflict in tensor_viewer
        if len(tensor_dict) > 0:
            destinations = list(tensor_dict.values())
        else:
            destinations = [tensor_dict]

        def _put(key, value):
            for destination in destinations:
                destination[key] = value

        # Variables registered in context
        v_fetches_dict = context.variables_to_export
        if len(v_fetches_dict) > 0:
            fetched = self.model.agent.session.run(
                list(v_fetches_dict.values()))
            for key, value in zip(v_fetches_dict.keys(), fetched):
                _put(key, value)

        # :: Add numpy arrays that stored in monitor
        if self.th.export_weight_grads:
            for key, value in context.monitor.grad_dict.items():
                _put(key, value)
        if self.th.export_activations:
            for key, value in context.monitor.stats_dict.items():
                _put(key, value)

        return tensor_dict

    def _take_notes_for_export(self):
        """Periodically take down scalars and tensors for export.

        Notes are taken every th.note_modulus iterations, and at iteration 1
        if th.take_note_in_beginning. Two further conditions apply: the loss
        statistic must have a truthy last value, and when validation is on,
        the metrics manager must be ready.
        """
        hub = self.th
        # Note switch should be turned on
        if hub.note_modulus == 0:
            return
        # Note cycle should be met
        if np.mod(self.counter, hub.note_modulus) != 0 and not (
                self.counter == 1 and hub.take_note_in_beginning):
            return
        # NOTE(review): a last loss of exactly 0 is also treated as "blank"
        # here -- confirm `not last_value` is the intended emptiness test.
        if not self.batch_loss_stat.last_value:
            return
        # Validation history should not be blank if validation is on
        if hub.validation_on and not self.metrics_manager.ready_for_note_taking:
            return

        # - Scalars
        scalars = OrderedDict()
        scalars['Loss'] = self.batch_loss_stat.running_average
        self.metrics_manager.update_scalar_dict(scalars)

        # - Tensors
        tensors = self._get_tensors_to_export()
        self.model.agent.take_down_scalars_and_tensors(scalars, tensors=tensors)
        self._inter_cut('Notes taken down.', prompt='[Export]')
        # For quickly note taking
        if hub.terminate_on_note:
            hub.force_terminate = True

    def _run_probe(self):
        """Call the probe every th.probe_modulus iterations.

        Non-empty content returned by the probe is shown via _inter_cut.
        Returns False when probing is skipped, and None otherwise.
        """
        if self._probe is None or self.th.probe_modulus == 0:
            return False
        if np.mod(self.counter, self.th.probe_modulus) != 0:
            return False
        content = self._probe(self)
        if not (content is None or content == ''):
            self._inter_cut(content, prompt='[Probe]',
                            start_time=self.th.start_time)

    def _etch(self):
        """Call context.pruner.etch_all() every th.etch_modulus iterations."""
        if self.th.etch_on and np.mod(self.counter, self.th.etch_modulus) == 0:
            pruner = context.pruner
            assert pruner is not None
            pruner.etch_all()

    def _validate_model(self, rnd):
        """Validate the model and record metric statistics.

        Validation runs every th.validate_modulus iterations, and at
        iteration 1 if th.take_note_in_beginning. The training and test sets
        may also be validated, but only the validation set is taken down on
        metric slots. On a new record, the terminator (if set) may force
        termination.

        :param rnd: current round, passed to record_stats_on_dataset
        :return: whether a new record appeared; False if validation skipped
        """
        hub = self.th
        if not hub.validation_on:
            return False
        # Validate cycle should be met
        if np.mod(self.counter, hub.validate_modulus) != 0 and not (
                self.counter == 1 and hub.take_note_in_beginning):
            return False

        def _run_validation(data_set, allow_sum, **extra):
            return self.model.validate_model(data_set,
                                             hub.val_batch_size,
                                             allow_sum=allow_sum,
                                             verbose=hub.val_progress_bar,
                                             **extra)

        # Validate training set if necessary
        if hub.validate_train_set:
            train_set = self.training_set
            self.metrics_manager.record_stats_on_dataset(
                train_set, _run_validation(train_set, False))

        # Validate val_set and record
        val_set = self.validation_set
        val_dict = _run_validation(val_set, hub.summary,
                                   seq_detail=hub.val_info_splits > 0)
        new_record = self.metrics_manager.record_stats_on_dataset(
            val_set, val_dict, True, rnd)
        # Terminator will check early_stop_criterion if new_record appears
        if new_record and callable(self._terminator):
            if self._terminator(self.metrics_manager.early_stop_criterion):
                hub.force_terminate = True

        # Validate test set if necessary TODO: BETA
        if hub.validate_test_set:
            test_set = self.test_set
            self.metrics_manager.record_stats_on_dataset(
                test_set, _run_validation(test_set, False))

        self.metrics_manager.print_latest_stats('[Validate]',
                                                decimals=hub.val_decimals)
        return new_record

    def _snapshot(self):
        """Periodically save a figure produced by the snapshot function.

        This runs when all of the following hold:
        - th.snapshot is on;
        - th.snapshot_cycle > 0;
        - (counter - 1) is a multiple of the cycle.
        It is skipped when no snapshot function was given. The figure is
        saved as 'train_<step>.png', where step is total_rounds if available
        and the counter otherwise.
        """
        if not self.th.snapshot: return
        if not self.th.snapshot_cycle > 0: return
        # Without a snapshot function there is nothing to plot; calling None
        # would raise a TypeError.
        if self._snapshot_function is None: return
        if np.mod(self.counter - 1, self.th.snapshot_cycle) != 0: return

        fig = self._snapshot_function(self.model)
        step = self.counter if self.total_rounds is None else self.total_rounds
        filename = 'train_{:.2f}.png'.format(step)
        self.model.agent.save_plot(fig, filename)
        self._inter_cut("Images saved to '{}'".format(filename), '[Snapshot]')

    def _save_model(self, inter_cut=False, progress=None):
        """Save the model through its agent and report it.

        For offline training, the saved round count is model.rounds plus the
        optional in-round progress (asserted to lie in [0, 1]). For online
        training, rounds=None is passed.

        :param inter_cut: report via _inter_cut instead of show_status
        :param progress: optional fraction of the current round
        """
        if self.is_online:
            total_rounds = None
        else:
            assert np.isscalar(self.model.rounds)
            total_rounds = self.model.rounds
            if progress is not None:
                assert 0 <= progress <= 1
                total_rounds += progress
        self.model.agent.save_model(rounds=total_rounds, suffix='train')
        report = self._inter_cut if inter_cut else console.show_status
        report('Model saved')

    # endregion : Private Methods

    # region : Public Methods

    def get_variables_to_export(self, export_dict=None):
        """Public wrapper of _get_variables_to_export for custom probes.

        :param export_dict: dict to fill, or None for a new OrderedDict
        :return: the filled dict
        """
        filled = self._get_variables_to_export(export_dict)
        return filled