Example #1
  def __init__(self, mark=None):
    # Model mark usually helps to decide the folder name
    # TODO: needs to be refactored
    self.mark = hub.mark or mark
    assert self.mark is not None
    if hub.prefix is not None: self.mark = hub.prefix + self.mark
    if hub.suffix is not None: self.mark += hub.suffix
    if hub.script_suffix is not None: self.mark += hub.script_suffix
    # TODO: set prune iteration number.
    #       At this time, config conflicts are not smoothed.
    if hub.prune_on or hub.pruning_rate_fc > 0:
      self.mark += '_pr{}'.format(hub.pruning_iterations)
    hub.mark = self.mark

    # Each model has an agent to deal with some tensorflow stuff
    self.agent = Agent(self)

    # Define slots
    # 2020-6-10 | William |
    #   outputs should be a Group, which is more general for error injection
    #   tframe 2.0 should use this approach to describe a Model
    self._outputs = TensorSlot(self)

    # Compromising way to enable additional error injection
    self._forms_for_injection = []

    self._metrics_manager = MetricsManager(self)

    self._validation_summary = SummarySlot(self)
    self._batch_val_summ = IndependentSummarySlot(self, 'batch_metric_summ')

    self._loss = TensorSlot(self, 'Loss')
    self._train_step = OperationSlot(self)
    self._train_step_summary = SummarySlot(self)

    self.validate_group = Group(
      self, self._validation_summary, name='Validate-group')

    self._update_group = Group(
      self, self._loss, self._train_step, self._train_step_summary,
      name='Update-group')

    # Slots for exporting np values to note
    self.grads_slot = NestedTensorSlot(self, 'Gradients')
    self.general_tensor_slot = NestedTensorSlot(self, 'General-Tensor')

    # Private attributes
    self._default_net = None  # TODO to be removed
    self._optimizer = None
    self._built = False
    self._scheme = None

    # Public attributes
    self.counter = None
    self.rounds = None
    self.launched = False

    # Quantities
    self.loss_quantity = None
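
A minimal stand-alone sketch of the mark composition above; the hub values here are hypothetical stand-ins, not taken from an actual trainer hub.

class FakeHub:
    mark, prefix, suffix, script_suffix = None, 'mnist_', None, '_s01'
    prune_on, pruning_rate_fc, pruning_iterations = True, 0.0, 3

hub = FakeHub()
mark = hub.mark or 'mlp00'
if hub.prefix is not None: mark = hub.prefix + mark
if hub.suffix is not None: mark += hub.suffix
if hub.script_suffix is not None: mark += hub.script_suffix
if hub.prune_on or hub.pruning_rate_fc > 0:
    mark += '_pr{}'.format(hub.pruning_iterations)
assert mark == 'mnist_mlp00_s01_pr3'  # the resulting folder-name mark
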
Example #2
 def init_monitor(self, model):
     from tframe.models import Model
     hub = tfr.context.hub
     assert isinstance(model, Model)
     # (2) Post-activation reception
     with tf.name_scope('Post_Activation'):
         self._receive_post_activation()
     # (3) Weight reception
     with tf.name_scope('Weights'):
         for weight in self._weight_lounge:
             self._round_end_summaries.append(
                 self._make_image_summary(tf.abs(weight),
                                          self._get_default_name(weight)))
     # (4) Add gradients of loss with respect to each weight variable
     with tf.name_scope('Weight_Grads'):
         if hub.monitor_grad: self._receive_weight_grad(model.loss.tensor)
     # (*) Wrap and register update_ops
     for op in self._update_ops:
         slot = OperationSlot(model)
         slot.plug(op)
         model._update_group.add(slot)
     # Organize round_end_group
     if len(self._round_end_summaries) > 0:
         slot = SummarySlot(model)
         with tf.name_scope('Monitor'):
             slot.plug(tf.summary.merge(self._round_end_summaries))
         self._round_end_group = Group(model, slot)
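
A minimal sketch of the weight-reception pattern in step (3) above, assuming the TF 1.x API used throughout this listing; the variable name and shape are hypothetical stand-ins for what _make_image_summary presumably receives.

import tensorflow as tf

w = tf.Variable(tf.random_normal([8, 8]), name='w')
with tf.name_scope('Weights'):
    # Render |w| as a single-channel image summary
    image = tf.reshape(tf.abs(w), [1, 8, 8, 1])
    summ = tf.summary.image('w_abs', image)
with tf.name_scope('Monitor'):
    # Analogous to the round-end merge at the bottom of init_monitor
    merged = tf.summary.merge([summ])
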
Example #3
 def _define_train_step(self, optimizer=None, var_list=None):
     assert len(self._losses) > 0
     with tf.name_scope('Optimizer'):
         if optimizer is None: optimizer = tf.train.AdamOptimizer(1e-4)
         self._optimizer = optimizer
         loss_index = 0
         # NOTE: the var_list parameter is discarded here; variables are
         #       re-accumulated per branch in the loop below
         var_list = []
         for i, net in enumerate(self.children):
             assert isinstance(net, Net)
             var_list += net.var_list
             if net.is_branch or self._inter_type == pedia.fork:
                 slot = OperationSlot(
                     self, name='train_step_{}'.format(loss_index + 1))
                 slot.plug(
                     optimizer.minimize(
                         loss=self._losses[loss_index].tensor,
                         var_list=var_list))
                 self._train_steps.append(slot)
                 loss_index += 1
                 var_list = []
     assert len(self._losses) == len(self._train_steps)
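
A pure-Python sketch of the var_list accumulation above: variables pile up child by child, and each branch's train step minimizes its loss over everything gathered since the previous branch. The child specs below are hypothetical.

children = [dict(vars=['w0'], is_branch=True),
            dict(vars=['w1'], is_branch=False),
            dict(vars=['w2'], is_branch=True)]
var_list, per_branch = [], []
for child in children:
    var_list += child['vars']
    if child['is_branch']:
        per_branch.append(list(var_list))
        var_list = []
assert per_branch == [['w0'], ['w1', 'w2']]
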
Example #4
    def __init__(self, mark=None):
        # Model mark usually helps to decide the folder name
        self.mark = hub.mark or mark
        assert self.mark is not None

        # Each model has an agent to deal with some tensorflow stuff
        self.agent = Agent(self)

        # Define slots
        self._outputs = TensorSlot(self)

        self._metric = Metric(self, 'metric')
        self._validation_summary = SummarySlot(self)
        self._batch_val_summ = IndependentSummarySlot(self,
                                                      'batch_metric_summ')
        self._validate_group = Group(self,
                                     self._metric,
                                     self._validation_summary,
                                     name='Validate-group')

        self._loss = TensorSlot(self, 'Loss')
        self._train_step = OperationSlot(self)
        self._train_step_summary = SummarySlot(self)
        self._update_group = Group(self,
                                   self._loss,
                                   self._metric,
                                   self._train_step,
                                   self._train_step_summary,
                                   name='Update-group')

        # Private attributes
        self._default_net = None
        self._optimizer = None
        self._built = False
        self._scheme = None

        # Public attributes
        self.counter = None
        self.launched = False
Example #5
class Model(object):
    """
  Base class of [all?] kinds of models built on TensorFlow
  """
    model_name = 'default'

    def __init__(self, mark=None):
        # Model mark usually helps to decide the folder name
        # TODO: needs to be refactored
        self.mark = hub.mark or mark
        assert self.mark is not None
        if hub.prefix is not None: self.mark = hub.prefix + self.mark
        if hub.suffix is not None: self.mark += hub.suffix
        if hub.script_suffix is not None: self.mark += hub.script_suffix
        # TODO: set prune iteration number.
        #       At this time, config conflicts are not smoothed.
        if hub.prune_on or hub.pruning_rate_fc > 0:
            self.mark += '_pr{}'.format(hub.pruning_iterations)
        hub.mark = self.mark

        # Each model has an agent to deal with some tensorflow stuff
        self.agent = Agent(self)

        # Define slots
        self._outputs = TensorSlot(self)

        self._metrics_manager = MetricsManager(self)

        self._validation_summary = SummarySlot(self)
        self._batch_val_summ = IndependentSummarySlot(self,
                                                      'batch_metric_summ')

        self._loss = TensorSlot(self, 'Loss')
        self._train_step = OperationSlot(self)
        self._train_step_summary = SummarySlot(self)

        self.validate_group = Group(self,
                                    self._validation_summary,
                                    name='Validate-group')

        self._update_group = Group(self,
                                   self._loss,
                                   self._train_step,
                                   self._train_step_summary,
                                   name='Update-group')

        # Slots for exporting np values to note
        self.grads_slot = NestedTensorSlot(self, 'Gradients')
        self.general_tensor_slot = NestedTensorSlot(self, 'General-Tensor')

        # Private attributes
        self._default_net = None  # TODO to be removed
        self._optimizer = None
        self._built = False
        self._scheme = None

        # Public attributes
        self.counter = None
        self.rounds = None
        self.launched = False

        # Quantities
        self.loss_quantity = None

    # region : Properties

    # region : Accessor

    @property
    def affix(self):
        return 'model'

    @property
    def graph(self):
        return self.agent.graph

    @property
    def session(self):
        return self.agent.session

    @property
    def metrics_manager(self):
        return self._metrics_manager

    @property
    def key_metric(self):
        if not self.metrics_manager.has_metric: return None
        return self.metrics_manager.early_stop_slot

    @property
    def eval_metric(self):
        if not self.metrics_manager.has_metric: return None
        return self.metrics_manager.eval_slot

    @property
    def outputs(self):
        assert isinstance(self._outputs, TensorSlot)
        return self._outputs

    @property
    def loss(self):
        assert isinstance(self._loss, TensorSlot)
        return self._loss

    @property
    def train_step(self):
        assert isinstance(self._train_step, OperationSlot)
        return self._train_step

    @property
    def built(self):
        assert isinstance(self._built, bool)
        return self._built

    @property
    def record(self):
        if not self.key_metric.activated: return None
        else: return self.key_metric.record

    @property
    def variable_to_save(self):
        """Should be called in with_graph decorator"""
        vars = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
                tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
        # Remove `do not save` vars
        vars = [
            var for var in vars
            if var not in tf.get_collection(pedia.do_not_save)
        ]

        filter_by_name = lambda key: [
            var for var in vars if key not in var.name
        ]
        # Remove `train_opt` vars if necessary
        if not hub.save_train_opt_vars:
            vars = filter_by_name(pedia.train_opt)
            vars = filter_by_name('Optimizer')
        # Remove `dynamic_opt` vars
        vars = filter_by_name(pedia.dynamic_opt)
        # Krause optimizer-related vars (TODO: needs to be refactored)
        vars = filter_by_name('de_theta0')
        if not hub.train_stats_exists: vars = filter_by_name('de_sqrt_MS_g')
        return vars

    @property
    def metric_foreach(self):
        metrics = tf.get_collection(pedia.metric_foreach)
        assert len(metrics) == 1
        return metrics[0]

    @property
    def parameters_dict(self):
        # Fetch all trainable variables
        trainable_variables = tf.trainable_variables()
        values = self.session.run(trainable_variables)
        # Wrap them into a dictionary and return
        parameters = {}
        for t, v in zip(trainable_variables, values):
            parameters[t.name] = v
        return parameters

    # endregion : Accessor

    # region : Properties to be overridden

    @property
    def description(self):
        return 'No description'

    @property
    def input_type(self):
        return InputTypes.BATCH

    # endregion : Properties to be overridden

    # endregion : Properties

    # region : Building

    @with_graph
    def build(self, **kwargs):

        # Smooth out flags before important actions
        hub.smooth_out_conflicts()
        # Initialize pruner if necessary
        if any([hub.prune_on, hub.weights_mask_on, hub.etch_on]):
            # import here to prevent circular import (temporarily)
            from tframe.operators.prune.pruner import Pruner
            tfr.context.pruner = Pruner(self)
        # If optimizer is not provided here, try hub.get_optimizer();
        #   this requires that th.optimizer and th.learning_rate have been provided
        if 'optimizer' not in kwargs: kwargs['optimizer'] = hub.get_optimizer()
        # Call successor's _build method
        self._build(**kwargs)
        # Initialize monitor
        self._init_monitor()
        # Set built flag
        self._built = True
        # Show build info
        console.show_status('Model built successfully:')
        self.agent.take_notes('Model built successfully')
        self.agent.take_notes('Structure:', date_time=False)
        # Description may be a model structure
        description = self.description
        if not isinstance(description, (tuple, list)):
            description = [description]
        for line in description:
            assert isinstance(line, str)
            console.supplement(line)
            self.agent.take_notes(line, date_time=False)

        # Add metric slot to update group
        batch_metric = kwargs.get('batch_metric', [])
        if batch_metric:
            if not isinstance(batch_metric, (tuple, list)):
                batch_metric = [batch_metric]
            for metric_str in batch_metric:
                assert isinstance(metric_str, str)
                metric_slot = self.metrics_manager.get_slot_by_name(metric_str)
                self._update_group.add(metric_slot)

        # Register eval_metric if provided
        eval_metric = kwargs.get('eval_metric', None)
        if eval_metric is not None:
            assert isinstance(eval_metric, str)
            self.metrics_manager.register_eval_slot(eval_metric)

    def _build(self, optimizer=None, **kwargs):
        """Abstract method, must be implemented in different models
       Usually touches tensorflow api directly and plug tf ops into tfr slots
    """
        raise NotImplementedError('!! build method not implemented')

    def _init_monitor(self):
        pass
        # TODO
        # if tfr.monitor.activated: tfr.monitor.init_monitor(self)

    @with_graph
    def _define_train_step(self, optimizer=None, var_list=None):
        """ TODO: should be modified for tframe.optimizer
        self._train_step will be plugged only here
    """
        if not self._loss.activated:
            raise AssertionError('!! loss has not been activated yet')
        with tf.name_scope('Optimizer'):
            if optimizer is None:
                optimizer = hub.get_optimizer()
                console.show_status(
                    'Optimizer defined in trainer hub has been initialized.', '++')

            # TODO: BETA
            if hub.use_rtrl:
                raise AssertionError('use_rtrl option has been deprecated')
                # Unreachable, kept for reference:
                # from tframe.optimizers.rtrl_opt import RealTimeOptimizer
                # optimizer = RealTimeOptimizer(self, optimizer)

            self._optimizer = optimizer
            self.set_train_step(var_list)

    def set_train_step(self, var_list=None):
        self._train_step.plug(
            self._optimizer.minimize(self._loss.op, var_list=var_list))

    def reset_optimizer(self):
        from tframe.optimizers.clip_opt import GradientClipOptimizer
        assert isinstance(self._optimizer, GradientClipOptimizer)
        self.session.run(self._optimizer.reset_tf_optimizer)
        console.show_status('TensorFlow optimizer has been reset.')

    def _merge_summaries(self):
        train_step_summaries = tf.get_collection(pedia.train_step_summaries)
        validation_summaries = tf.get_collection(pedia.validation_summaries)
        if len(train_step_summaries) > 0:
            self._train_step_summary.plug(
                tf.summary.merge(train_step_summaries))
        if len(validation_summaries) > 0:
            self._validation_summary.plug(
                tf.summary.merge(validation_summaries))

    # endregion : Building

    # region : Training

    def pretrain(self, **kwargs):
        """Method run in early training process, should be overrode"""
        if self._scheme is not None:
            assert isinstance(self._scheme, TrainScheme)
            trial = self._scheme.dequeue()
            if trial is not None: trial.initialize(self)

    @with_graph
    def train(self,
              training_set,
              trainer_hub=None,
              validation_set=None,
              snapshot=None,
              probe=None,
              evaluate=None,
              terminator=None,
              test_set=None,
              **kwargs):
        if trainer_hub is None:
            trainer_class = SmartTrainer if hub.smart_train else Trainer
        else:
            if not isinstance(trainer_hub, TrainerHub):
                raise TypeError(
                    '!! Input hub must be an instance of TrainerHub')
            trainer_class = trainer_hub.trainer_class
        trainer = trainer_class(self,
                                training_set=training_set,
                                validation_set=validation_set,
                                snapshot=snapshot,
                                probe=probe,
                                evaluate=evaluate,
                                terminator=terminator,
                                test_set=test_set)
        trainer.train(hub=trainer_hub, **kwargs)

    def update_model(self, data_batch, **kwargs):
        """Default model updating method, should be overrode"""
        feed_dict = self._get_default_feed_dict(data_batch, is_training=True)
        return self._update_group.run(feed_dict)

    def get_data_batches(self,
                         data_set,
                         batch_size,
                         num_steps=None,
                         shuffle=False,
                         is_training=False):
        """ Get batch generator. This method is used both in training and
        evaluation/validation.

        It's trivial for FNN models. However, for RNN models, data_set may be
        (1) a SequenceSet in which the feature is a list of numpy arrays.
            each represents a sequence and the lengths may vary.
            e.g.
            data_set.feature = [
               xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx,
               xxxxxxxxxxxxxxxxxxx,
               xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx,
               xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx,
            ], in which x represents a data point.
            In this case, batch size can be
            (a) 1 (by default)
            (b) larger than 1, active_len mechanism will be used,
                val_num_steps will be forced to -1
        (2) a DataSet consists of a single sequence, shape = [seq_len, *dim]
            In this case, batch size can be any integer

    :param data_set: an instance of DataSet or BigData from which data batches
                      will be extracted
    :param batch_size: if is None, default value will be assigned according to
                        the input type of this model
    :param num_steps: step number for RNN data batches
    :param shuffle: whether to shuffle
    :return: a generator or a list
    """
        # Data set must be an instance of DataSet or BigData
        assert isinstance(data_set, (DataSet, BigData, PerpetualMachine))

        if self.input_type is InputTypes.BATCH:
            # 1. For FNN, `num_steps` will be ignored, default batch_size is -1 (all)

            # If batch size is not specified and data is a DataSet, feed it all at
            #  once into model
            if batch_size is None and isinstance(data_set, DataSet):
                return [data_set.stack]

            # Otherwise batch_size must be a positive integer
            checker.check_positive_integer(batch_size)
            data_batches = data_set.gen_batches(batch_size,
                                                shuffle=shuffle,
                                                is_training=is_training)

        elif self.input_type is InputTypes.RNN_BATCH:
            # 2. For RNN, default batch_size is 1, default num_steps is -1 (all)
            #
            if num_steps is None: num_steps = -1
            if batch_size is None: batch_size = 1
            if batch_size < 0: batch_size = data_set.size

            # Cases:
            # (1) data_set is a DataSet but not a SequenceSet
            #     each data entry in data_dict will be considered as a consecutive
            #     sequence. batch_size and num_steps can be any integer
            # (2) data_set is a SequenceSet
            #     ---------------+------------------------+--------------------------
            #                    | num_steps = -1         | num_steps != -1
            #     ---------------+------------------------+--------------------------
            #                    |                        |
            #     batch_size = 1 | legal for all          | legal for all *
            #                    |                        |
            #     ---------------+------------------------+--------------------------
            #                    | train: legal for equal-length sequences since
            #                    |        act_len logic has not been implemented
            #     batch_size > 1 |        for training *
            #                    +------------------------+--------------------------
            #                    | val: legal for all     | TODO: not supported
            #     ---------------+------------------------+--------------------------
            #                                             | * n_to_one must be False

            # Check batch_size
            # it's legal for common DataSet to have num_steps > 0 while batch_size > 1
            checker.check_positive_integer(batch_size)
            if batch_size > 1 and isinstance(data_set, SequenceSet):
                # The constraint `num_steps < 0` is not necessary due to the
                # gather_indices mechanism:
                # if is_training and not hub.use_gather_indices:
                #   assert data_set.equal_length
                pass

            # Check num_steps
            checker.check_type(num_steps, int)
            if num_steps != -1:
                # partition logic for n_to_one task has not been implemented yet
                assert not data_set.n_to_one

            # Generate batches
            data_batches = data_set.gen_rnn_batches(batch_size,
                                                    num_steps,
                                                    shuffle,
                                                    is_training=is_training)
        else:
            raise ValueError('!! Can not resolve input type of this model')

        return data_batches
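
    # A hedged usage sketch for the method above; `model` and `train_set`
    # are assumptions:
    #   for batch in model.get_data_batches(train_set, batch_size=32,
    #                                       shuffle=True, is_training=True):
    #       model.update_model(batch)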

    def validate_model(self,
                       data_set,
                       batch_size=None,
                       allow_sum=False,
                       verbose=False,
                       seq_detail=False,
                       num_steps=None):
        """Evaluate quantities in validate group of this model
    :param data_set: a tframe DataSet
    :param batch_size: if is None or -1, batch_size will be data_set.size
    :param allow_sum: whether to add tensorflow summaries TODO: to be deprecated
    :return: a dictionary in which keys are slots (may include loss and metric)
             and values are scalars corresponding to these slots
    """
        assert isinstance(data_set, DataSet)
        if num_steps is None: num_steps = hub.val_num_steps

        # - One-shot validation
        one_shot = False
        batch_is_all = batch_size in (-1, None) or batch_size == data_set.size
        # .. check one-shot qualification
        # .. .. the code below should be encapsulated
        if self.input_type is InputTypes.BATCH:
            # (1) e.g. small model on MNIST, CIFAR-10
            if batch_is_all: one_shot = True
        elif self.input_type is InputTypes.RNN_BATCH:
            # (2)
            if isinstance(data_set, SequenceSet):
                # (2-a)
                if batch_is_all and num_steps == -1 and data_set.equal_length:
                    # e.g. AP, TO
                    one_shot = True
            else:
                # (2-b)
                assert isinstance(data_set, DataSet)
                # assert batch_size in (1, -1, None)  # TODO
                # e.g. small model on WHB
                if num_steps == -1: one_shot = True

        # .. do one-shot validation if qualified
        # .. for RNN models, the reset_batch flag of data_set should be set
        if one_shot:
            data_set = self._sanity_check_before_use(data_set)
            if self.input_type is InputTypes.RNN_BATCH:
                data_set.should_reset_state = True
            feed_dict = self._get_default_feed_dict(data_set,
                                                    is_training=False)
            return self.validate_group.run(feed_dict, allow_sum=allow_sum)

        # - Otherwise do batch validation
        tensor_slots = self.validate_group.tensor_slots
        quantity_defs = [s.quantity_definition for s in tensor_slots]
        fetches = [q.quantities for q in quantity_defs]
        values = self.evaluate(fetches,
                               data_set,
                               batch_size,
                               verbose=verbose,
                               num_steps=num_steps)
        result_dict = OrderedDict()

        for val, qd, slot in zip(values, quantity_defs, tensor_slots):
            # Sanity check
            assert isinstance(qd, Quantity)
            if self.input_type is InputTypes.BATCH:
                assert isinstance(val, np.ndarray) and len(val) > 0
            else:
                assert isinstance(val, list)
                if not data_set.n_to_one:
                    checker.check_type(val, np.ndarray)
            # Apply np_summ_method on val
            scalar = qd.apply_np_summ_method(val, seq_detail)
            # Add summ to results
            result_dict[slot] = scalar

        return result_dict
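
    # A hedged usage sketch; `model` and `val_set` are assumptions:
    #   results = model.validate_model(val_set, batch_size=-1, verbose=True)
    #   for slot, scalar in results.items():
    #       print(slot, scalar)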

    def take_down_metric(self, is_online):
        for metric in self.metrics_manager.metrics:
            assert isinstance(metric, MetricSlot)
            if not metric.activated: continue
            notes = 'Best {}: {:.3f}'.format(metric.symbol, metric.record)
            # if not is_online:
            #   notes += ', Best {} = {:.3f}'.format(metric.symbol, metric.mean_record)
            self.agent.take_notes(notes, date_time=False)

            # Add history into notes if necessary
            if hub.show_record_history_in_note:
                self.agent.take_notes(metric.metric_mean_history_str,
                                      date_time=False)
            # Add record and mean record to notes
            self.agent.put_down_criterion('Best {}'.format(metric.symbol),
                                          metric.record)
            # if not is_online:
            #   self.agent.put_down_criterion('Best E({})', metric.mean_record)

    def end_round(self, rnd):
        self.key_metric.end_round(rnd)

    def bust(self, rnd):
        if self._scheme is not None:
            assert isinstance(self._scheme, TrainScheme)
            trial = self._scheme.dequeue()
            if trial is not None:
                trial.initialize(self)
                return False
            else:
                return True
        return True

    # endregion : Training

    # region : Public Methods

    def handle_structure_detail(self):
        detail, total_params, dense_total = '', 0, 0
        if hasattr(self, 'structure_detail'):
            detail, total_params, dense_total = self.structure_detail
        # Maybe take some notes
        params_str = 'Total params: {}'.format(total_params)
        hub.total_params = int(total_params)
        if hub.prune_on:
            hub.dense_total_params = dense_total
            hub.weights_fraction = 100.0 * total_params / dense_total
            params_str += ' ({:.2f}%)'.format(hub.weights_fraction)
        self.agent.take_notes(params_str)

        if hub.show_structure_detail:
            print('.. Structure detail:\n{}'.format(detail))

    def get_trainable_variables(self, f=None):
        if f is None: f = lambda _: True
        variables = [v for v in tf.trainable_variables() if f(v)]
        values = self.session.run(variables)
        variable_dict = OrderedDict()
        for t, v in zip(variables, values):
            variable_dict[t.name] = v
        return variable_dict

    def tune_lr(self, new_lr=None, coef=1.0):
        #TODO
        if self._optimizer is None:
            raise ValueError('!! Optimizer not defined yet')
        if self._optimizer.__class__ in [tf.train.AdamOptimizer]:
            lr_name = '_lr'
        elif self._optimizer.__class__ in [tf.train.GradientDescentOptimizer]:
            lr_name = '_learning_rate'
        else:
            raise TypeError('!! Unsupported optimizer for lr tuning')

        old_lr = self._optimizer.__getattribute__(lr_name)
        if new_lr is None: new_lr = old_lr * coef
        self._optimizer.__setattr__(lr_name, new_lr)

        # Show status
        console.show_status('Learning rate updated: {:.2e} => {:.2e}'.format(
            old_lr, new_lr))

        return new_lr
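
    # A hedged usage sketch, assuming a launched model with an Adam optimizer:
    #   model.tune_lr(coef=0.5)      # halve the current learning rate
    #   model.tune_lr(new_lr=1e-4)   # or set it explicitly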

    def set_scheme(self, scheme):
        if not isinstance(scheme, TrainScheme):
            raise TypeError('!! scheme must be an instance of TrainScheme')
        self._scheme = scheme

    def shutdown(self):
        self.agent.shutdown()

    def launch_model(self, overwrite=False):
        return self.agent.launch_model(overwrite)

    def evaluate(self,
                 fetches,
                 data,
                 batch_size=None,
                 postprocessor=None,
                 verbose=False,
                 num_steps=None,
                 suppress_n_to_one=False):
        """
    Evaluate tensors based on data
    TODO: note that if num_steps != -1, outputs from a same sequence may be
          partitioned. e.g., if single_fetch, outputs will be
          [array_1_1, ..., array_1_k1, array_2_1, ..., array_2_k2, ...]
         |-------- input_1 ----------|------------ input_2 ----------|
         it's OK for seq2seq validation, but need to be post-proceeded in
         tasks like sequence classification (currently forbidden)

    :param fetches: a (tuple/list of) tf.Tensor(s) to be evaluated
    :param data: data used for evaluation
    :param batch_size: if not specified (None by default), batch_size will be
                       assigned accordingly. If assigned with a positive
                       integer, evaluation will be performed batch by batch.
    :param postprocessor: post-processor for outputs
    :return: commonly a (list of) tf.Tensor(s), each of which has the
             same batch size with the provided data
    """
        # Sanity check for fetches
        checker.check_fetchable(fetches)
        single_fetch = not isinstance(fetches, (tuple, list))
        # Wrap fetches into a list if necessary
        if single_fetch: fetches = [fetches]
        if num_steps is None: num_steps = hub.val_num_steps
        if batch_size is None: batch_size = data.size

        # Get outputs (sometimes fetches may contain operations which yield None)
        outputs = [[] for op in fetches if not isinstance(op, tf.Operation)]

        if verbose:
            bar = ProgressBar(data.get_round_length(batch_size, num_steps))
            console.show_status('Evaluating on {} ...'.format(data.name))

        for cursor, data_batch in enumerate(
                self.get_data_batches(data, batch_size, num_steps)):
            data_batch = self._sanity_check_before_use(data_batch)
            # Get batch outputs          fetches[0]  fetches[1]
            #  for FNN, batch_outputs = [np_array_1, np_array_2, ...]
            #           where each np_array_k has the same batch_size
            #  for RNN, batch_outputs = [[s1_1, s1_2, ..., s1_N],       <= fetches[0]
            #                            [s2_1, s2_2, ..., s2_N], ...]  <= fetches[1]
            #           N is the batch_size, and each sk_i is a numpy array
            batch_outputs = self._evaluate_batch(
                fetches,
                data_batch,
                num_steps=num_steps,
                suppress_n_to_one=suppress_n_to_one)
            assert isinstance(batch_outputs, list)
            assert len(batch_outputs) == len(outputs)

            # Add batch_outputs to outputs accordingly
            for i, batch_output in enumerate(batch_outputs):
                assert isinstance(outputs[i], list)
                output_is_a_batch = fetches[i].shape.as_list()[0] is None
                if self.input_type is InputTypes.RNN_BATCH and output_is_a_batch:
                    # batch_output is [s1_1, s1_2, ..., s1_N]
                    assert isinstance(batch_output, list)
                    outputs[i] = outputs[i] + batch_output
                else:
                    # batch_output is a numpy array of length batch_size
                    outputs[i].append(batch_output)

            # Show progress bar if necessary
            if verbose: bar.show(cursor + 1)

        # Merge outputs if necessary
        if self.input_type is InputTypes.BATCH:
            outputs = [
                np.concatenate(array_list, axis=0) for array_list in outputs
            ]

        # Post-process and return
        if postprocessor is not None:
            assert callable(postprocessor)
            outputs = postprocessor(outputs)

        assert isinstance(outputs, list)
        if single_fetch: outputs = outputs[0]
        return outputs
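
    # A hedged usage sketch; `model.outputs` is assumed to have been plugged
    # and `test_set` to be a DataSet:
    #   predictions = model.evaluate(model.outputs.tensor, test_set,
    #                                batch_size=128, verbose=True)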

    # endregion : Public Methods

    # region : Private Methods

    def _evaluate_batch(self, fetch_list, data_set, **kwargs):
        raise NotImplementedError

    @with_graph
    def _get_default_feed_dict(self, batch, is_training):
        feed_dict = {}
        for tensor in tf.get_collection(pedia.default_feed_dict):
            if 'input' in tensor.name.lower():
                feed_dict[tensor] = batch[pedia.features]
            elif tensor.name.lower() in ('target', 'targets'):
                # elif 'target' in tensor.name:
                # TODO: when predicting without outputting loss ...
                if batch.targets is not None: feed_dict[tensor] = batch.targets
            elif pedia.gather_indices in tensor.name:
                # TODO: when batch.size is 1, gather_indices is not necessary.
                #       However, Quantity will never know the exact batch size
                feed_dict[tensor] = batch.gather_indices
            else:
                name = tensor.name.split('/')[-1].split(':')[0]
                val = batch.data_dict.get(name, None)
                if val is not None: feed_dict[tensor] = val

        feed_dict.update(self.agent.get_status_feed_dict(is_training))

        return feed_dict

    def _sanity_check_before_use(self, data):
        # Make sure data is legal
        if not isinstance(data, DataSet):
            raise TypeError('!! Input data must be an instance of DataSet')
        # Make sure model has been built
        if not self.built: raise ValueError('!! Model not built yet')
        # Make sure model has been launched
        if not self.launched: self.launch_model(overwrite=False)
        # Make sure data type matches model input type
        if self.input_type is InputTypes.RNN_BATCH: data = data.as_rnn_batch
        else: assert not data.is_rnn_input
        return data
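
A hedged end-to-end sketch of the training API in Example #5; `Predictor`, the layer stack, and the data sets (`train_set`, `val_set`, `test_set`) are assumptions based on typical tframe scripts, not taken from this listing.

from tframe import Predictor

model = Predictor(mark='mlp00')
# ... add layers to the model here ...
model.build(metric='mae', batch_metric='mae')
model.train(train_set, validation_set=val_set, test_set=test_set)
outputs = model.evaluate(model.outputs.tensor, test_set, batch_size=128)
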
Example #6
class Model(object):
    """
  Base class of [all?] kinds of models built on TensorFlow
  """
    model_name = 'default'

    def __init__(self, mark=None):
        # Model mark usually helps to decide the folder name
        self.mark = hub.mark or mark
        assert self.mark is not None

        # Each model has an agent to deal with some tensorflow stuff
        self.agent = Agent(self)

        # Define slots
        self._outputs = TensorSlot(self)

        self._metric = Metric(self, 'metric')
        self._validation_summary = SummarySlot(self)
        self._batch_val_summ = IndependentSummarySlot(self,
                                                      'batch_metric_summ')
        self._validate_group = Group(self,
                                     self._metric,
                                     self._validation_summary,
                                     name='Validate-group')

        self._loss = TensorSlot(self, 'Loss')
        self._train_step = OperationSlot(self)
        self._train_step_summary = SummarySlot(self)
        self._update_group = Group(self,
                                   self._loss,
                                   self._metric,
                                   self._train_step,
                                   self._train_step_summary,
                                   name='Update-group')

        # Private attributes
        self._default_net = None
        self._optimizer = None
        self._built = False
        self._scheme = None

        # Public attributes
        self.counter = None
        self.launched = False

    # region : Properties

    # region : Accessor

    @property
    def graph(self):
        return self.agent.graph

    @property
    def session(self):
        return self.agent.session

    @property
    def metric(self):
        if self._metric is not None:
            assert isinstance(self._metric, Metric)
        return self._metric

    @property
    def outputs(self):
        assert isinstance(self._outputs, TensorSlot)
        return self._outputs

    @property
    def loss(self):
        assert isinstance(self._loss, TensorSlot)
        return self._loss

    @property
    def train_step(self):
        assert isinstance(self._train_step, OperationSlot)
        return self._train_step

    @property
    def built(self):
        assert isinstance(self._built, bool)
        return self._built

    @property
    def record(self):
        if not self.metric.activated: return None
        else: return self.metric.record

    @property
    def variable_to_save(self):
        """Should be called in with_graph decorator"""
        vars = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
                tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
        return [
            var for var in vars
            if var not in tf.get_collection(pedia.do_not_save)
        ]

    # endregion : Accessor

    # region : Properties to be overridden

    @property
    def description(self):
        return 'No description'

    @property
    def input_type(self):
        return InputTypes.BATCH

    # endregion : Properties to be overridden

    # endregion : Properties

    # region : Building

    @with_graph
    def build(self, optimizer=None, **kwargs):
        # Smooth out flags before important actions
        hub.smooth_out_conflicts()
        # Call successor's _build method
        self._build(optimizer=optimizer, **kwargs)
        # Initialize monitor
        self._init_monitor()
        # Set built flag
        self._built = True
        # Show build info
        console.show_status('Model built successfully:')
        description = self.description
        if not isinstance(description, (tuple, list)):
            description = [description]
        for line in description:
            assert isinstance(line, str)
            console.supplement(line)
        # Maybe take some notes
        self.agent.take_notes('Model built successfully')
        self.agent.take_notes('Structure:', date_time=False)
        for line in description:
            self.agent.take_notes(line, date_time=False)

    def _build(self, optimizer=None, **kwargs):
        """Abstract method, must be implemented in different models"""
        raise NotImplementedError('!! build method not implemented')

    def _init_monitor(self):
        if tfr.monitor.activated: tfr.monitor.init_monitor(self)

    @with_graph
    def _define_train_step(self, optimizer=None, var_list=None):
        if not self._loss.activated:
            raise AssertionError('!! loss has not been activated yet')
        with tf.name_scope('Optimizer'):
            if optimizer is None: optimizer = tf.train.AdamOptimizer(1e-4)
            self._optimizer = optimizer
            self._train_step.plug(
                optimizer.minimize(self._loss.op, var_list=var_list))

    def _merge_summaries(self):
        train_step_summaries = tf.get_collection(pedia.train_step_summaries)
        validation_summaries = tf.get_collection(pedia.validation_summaries)
        if len(train_step_summaries) > 0:
            self._train_step_summary.plug(
                tf.summary.merge(train_step_summaries))
        if len(validation_summaries) > 0:
            self._validation_summary.plug(
                tf.summary.merge(validation_summaries))

    # endregion : Building

    # region : Training

    def pretrain(self, **kwargs):
        """Method run in early training process, should be overrode"""
        if self._scheme is not None:
            assert isinstance(self._scheme, TrainScheme)
            trial = self._scheme.dequeue()
            if trial is not None: trial.initialize(self)

    @with_graph
    def train(self,
              training_set,
              trainer_hub=None,
              validation_set=None,
              snapshot=None,
              probe=None,
              **kwargs):
        if trainer_hub is None:
            trainer_class = SmartTrainer if hub.smart_train else Trainer
        else:
            if not isinstance(trainer_hub, TrainerHub):
                raise TypeError(
                    '!! Input hub must be an instance of TrainerHub')
            trainer_class = trainer_hub.trainer_class
        trainer = trainer_class(self,
                                training_set=training_set,
                                validation_set=validation_set,
                                snapshot=snapshot,
                                probe=probe)
        trainer.train(hub=trainer_hub, **kwargs)

    def update_model(self, data_batch, **kwargs):
        """Default model updating method, should be overrode"""
        feed_dict = self._get_default_feed_dict(data_batch, is_training=True)
        return self._update_group.run(feed_dict)

    def get_data_batches(self,
                         data_set,
                         batch_size,
                         num_steps=None,
                         shuffle=False):
        """ Get batch generator.
    :param data_set: an instance of DataSet or BigData from which data batches
                      will be extracted
    :param batch_size: if None, a default value will be assigned according to
                        the input type of this model
    :param num_steps: step number for RNN data batches
    :param shuffle: whether to shuffle
    :return: a generator or a list
    """
        # Data set must be an instance of DataSet or BigData
        assert isinstance(data_set, (DataSet, BigData))
        if self.input_type is InputTypes.BATCH:
            # If model's input type is normal batch, num_steps will be ignored
            # If batch size is not specified and data is a DataSet, feed it all at
            #  once into model
            if batch_size is None and isinstance(data_set, DataSet):
                return [data_set.stack]
            checker.check_positive_integer(batch_size)
            data_batches = data_set.gen_batches(batch_size, shuffle=shuffle)
        elif self.input_type is InputTypes.RNN_BATCH:
            if batch_size is None: batch_size = 1
            if num_steps is None: num_steps = -1
            checker.check_positive_integer(batch_size)
            checker.check_type(num_steps, int)
            data_batches = data_set.gen_rnn_batches(batch_size, num_steps,
                                                    shuffle)
        else:
            raise ValueError('!! Can not resolve input type of this model')
        return data_batches

    def validate_model(self, data, batch_size=None, allow_sum=False):
        """Validate model. If data provided is not regular, batch validation will
       be used. For RNN model, batch validation requires batch size to be 1."""
        assert isinstance(data, TFRData)
        if not data.is_regular_array and batch_size is None: batch_size = 1
        # Normal validation
        if batch_size is None:
            data = self._sanity_check_before_use(data)
            feed_dict = self._get_default_feed_dict(data, is_training=False)
            return self._validate_group.run(feed_dict, allow_sum=allow_sum)
        # Batch validation: Calculate metric one by one
        metric_list = []
        total = 0
        for batch in self.get_data_batches(data, batch_size, -1, False):
            # Calculate weight
            weight = batch.targets.shape[0]
            if self.input_type is InputTypes.RNN_BATCH:
                weight *= batch.targets.shape[1]
            assert weight > 0
            total += weight
            # Validate batch
            batch = self._sanity_check_before_use(batch)
            feed_dict = self._get_default_feed_dict(batch, is_training=False)
            metric_list.append(self._metric.run(feed_dict) * weight)
        # Return metric mean
        metric_mean = np.sum(metric_list) / total
        if allow_sum: self._batch_val_summ.write(metric_mean)
        return {self._metric: metric_mean}
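
    # A pure-arithmetic check of the weighted mean above: with batch metric
    # values [0.2, 0.4] and weights [30, 10], the mean is
    # (0.2*30 + 0.4*10) / 40 = 0.25, so each batch contributes in proportion
    # to the number of targets it contains.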

    def take_down_metric(self):
        if not self.metric.activated: return
        notes = 'Record: {:.3f}, Mean Record: {:.3f}'.format(
            self.metric.record, self.metric.mean_record)
        self.agent.take_notes(notes, date_time=False)

    # TODO
    # def begin_round(self, **kwargs):
    #   pass

    def end_round(self, rnd):
        self.metric.end_round(rnd)

    def bust(self, rnd):
        if self._scheme is not None:
            assert isinstance(self._scheme, TrainScheme)
            trial = self._scheme.dequeue()
            if trial is not None:
                trial.initialize(self)
                return False
            else:
                return True
        return True

    # endregion : Training

    # region : Public Methods

    def tune_lr(self, new_lr=None, coef=1.0):
        #TODO
        if self._optimizer is None:
            raise ValueError('!! Optimizer not defined yet')
        if self._optimizer.__class__ in [tf.train.AdamOptimizer]:
            lr_name = '_lr'
        elif self._optimizer.__class__ in [tf.train.GradientDescentOptimizer]:
            lr_name = '_learning_rate'
        else:
            raise TypeError('!! Unsupported optimizer for lr tuning')

        old_lr = self._optimizer.__getattribute__(lr_name)
        if new_lr is None: new_lr = old_lr * coef
        self._optimizer.__setattr__(lr_name, new_lr)

        # Show status
        console.show_status('Learning rate updated: {:.2e} => {:.2e}'.format(
            old_lr, new_lr))

        return new_lr

    def set_scheme(self, scheme):
        if not isinstance(scheme, TrainScheme):
            raise TypeError('!! scheme must be an instance of TrainScheme')
        self._scheme = scheme

    def shutdown(self):
        self.agent.shutdown()

    def launch_model(self, overwrite=False):
        return self.agent.launch_model(overwrite)

    # endregion : Public Methods

    # region : Private Methods

    @with_graph
    def _get_default_feed_dict(self, batch, is_training):
        feed_dict = {}
        for tensor in tf.get_collection(pedia.default_feed_dict):
            if 'input' in tensor.name.lower():
                feed_dict[tensor] = batch[pedia.features]
            elif 'target' in tensor.name:
                # TODO: when predicting without outputting loss ...
                if batch.targets is not None: feed_dict[tensor] = batch.targets
            else:
                name = tensor.name.split('/')[-1].split(':')[0]
                val = batch.data_dict.get(name, None)
                if val is not None: feed_dict[tensor] = val

        feed_dict.update(self.agent.get_status_feed_dict(is_training))

        return feed_dict

    def _sanity_check_before_use(self, data):
        if not isinstance(data, DataSet):
            raise TypeError('!! Input data must be an instance of DataSet')
        if not self.built: raise ValueError('!! Model not built yet')
        if not self.launched: self.launch_model(overwrite=False)
        if self.input_type is InputTypes.RNN_BATCH: data = data.as_rnn_data
        else: assert not data.in_rnn_format
        return data