def __init__(self, mark=None):
  # Model mark usually helps to decide the folder name
  # TODO: need to be refactored
  self.mark = hub.mark or mark
  assert mark is not None
  if hub.prefix is not None: self.mark = hub.prefix + self.mark
  if hub.suffix is not None: self.mark += hub.suffix
  if hub.script_suffix is not None: self.mark += hub.script_suffix
  # TODO: set prune iteration number.
  #       At this time configs conflicts are not smoothed.
  if hub.prune_on or hub.pruning_rate_fc > 0:
    self.mark += '_pr{}'.format(hub.pruning_iterations)
  hub.mark = self.mark

  # Each model has an agent to deal with some tensorflow stuff
  self.agent = Agent(self)

  # Define slots
  # 2020-6-10 | William |
  # outputs should be a Group which is more general for error injection
  # tframe 2.0 should be using such way to describe a Model
  self._outputs = TensorSlot(self)
  # Compromising way to enable additional error injection
  self._forms_for_injection = []

  self._metrics_manager = MetricsManager(self)

  self._validation_summary = SummarySlot(self)
  self._batch_val_summ = IndependentSummarySlot(self, 'batch_metric_summ')

  self._loss = TensorSlot(self, 'Loss')
  self._train_step = OperationSlot(self)
  self._train_step_summary = SummarySlot(self)

  self.validate_group = Group(
    self, self._validation_summary, name='Validate-group')

  self._update_group = Group(
    self, self._loss, self._train_step, self._train_step_summary,
    name='Update-group')

  # Slots for exporting np values to note
  self.grads_slot = NestedTensorSlot(self, 'Gradients')
  self.general_tensor_slot = NestedTensorSlot(self, 'General-Tensor')

  # Private attributes
  self._default_net = None  # TODO: to be removed
  self._optimizer = None
  self._built = False
  self._scheme = None

  # Public attributes
  self.counter = None
  self.rounds = None
  self.launched = False

  # Quantities
  self.loss_quantity = None
def init_monitor(self, model):
  from tframe.models import Model
  hub = tfr.context.hub
  assert isinstance(model, Model)

  # (2) Post-activation reception
  with tf.name_scope('Post_Activation'):
    self._receive_post_activation()

  # (3) Weight reception
  with tf.name_scope('Weights'):
    for weight in self._weight_lounge:
      self._round_end_summaries.append(self._make_image_summary(
        tf.abs(weight), self._get_default_name(weight)))

  # (4) Add gradients of loss with respect to each weight variable
  with tf.name_scope('Weight_Grads'):
    if hub.monitor_grad:
      self._receive_weight_grad(model.loss.tensor)

  # (*) Wrap and register update_ops
  for op in self._update_ops:
    slot = OperationSlot(model)
    slot.plug(op)
    model._update_group.add(slot)

  # Organize round_end_group
  if len(self._round_end_summaries) > 0:
    slot = SummarySlot(model)
    with tf.name_scope('Monitor'):
      slot.plug(tf.summary.merge(self._round_end_summaries))
    self._round_end_group = Group(model, slot)
def _define_train_step(self, optimizer=None, var_list=None):
  assert len(self._losses) > 0
  with tf.name_scope('Optimizer'):
    if optimizer is None: optimizer = tf.train.AdamOptimizer(1e-4)
    self._optimizer = optimizer

    loss_index = 0
    var_list = []
    for i, net in enumerate(self.children):
      assert isinstance(net, Net)
      var_list += net.var_list
      if net.is_branch or self._inter_type == pedia.fork:
        slot = OperationSlot(
          self, name='train_step_{}'.format(loss_index + 1))
        slot.plug(optimizer.minimize(
          loss=self._losses[loss_index].tensor, var_list=var_list))
        self._train_steps.append(slot)
        loss_index += 1
        var_list = []

  assert len(self._losses) == len(self._train_steps)
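
# Illustrative sketch, not part of the original source: one hypothetical way to
# fire the per-branch train steps collected above with a shared feed dict. It
# assumes each OperationSlot exposes the same `run(feed_dict)` interface used
# by other slots in this codebase (e.g. `self._metric.run(feed_dict)` in the
# older Model.validate_model further below).
def _run_branch_train_steps(self, feed_dict):
  """Hypothetical helper: run every branch's train step in order."""
  assert len(self._train_steps) == len(self._losses)
  for step_slot in self._train_steps:
    step_slot.run(feed_dict)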
def __init__(self, mark=None):
  # Model mark usually helps to decide the folder name
  self.mark = hub.mark or mark
  assert mark is not None

  # Each model has an agent to deal with some tensorflow stuff
  self.agent = Agent(self)

  # Define slots
  self._outputs = TensorSlot(self)

  self._metric = Metric(self, 'metric')
  self._validation_summary = SummarySlot(self)
  self._batch_val_summ = IndependentSummarySlot(self, 'batch_metric_summ')
  self._validate_group = Group(
    self, self._metric, self._validation_summary, name='Validate-group')

  self._loss = TensorSlot(self, 'Loss')
  self._train_step = OperationSlot(self)
  self._train_step_summary = SummarySlot(self)
  self._update_group = Group(
    self, self._loss, self._metric, self._train_step,
    self._train_step_summary, name='Update-group')

  # Private attributes
  self._default_net = None
  self._optimizer = None
  self._built = False
  self._scheme = None

  # Public attributes
  self.counter = None
  self.launched = False
class Model(object):
  """
  Base class of [all?] kinds of models built on TensorFlow
  """
  model_name = 'default'

  def __init__(self, mark=None):
    # Model mark usually helps to decide the folder name
    # TODO: need to be refactored
    self.mark = hub.mark or mark
    assert mark is not None
    if hub.prefix is not None: self.mark = hub.prefix + self.mark
    if hub.suffix is not None: self.mark += hub.suffix
    if hub.script_suffix is not None: self.mark += hub.script_suffix
    # TODO: set prune iteration number.
    #       At this time configs conflicts are not smoothed.
    if hub.prune_on or hub.pruning_rate_fc > 0:
      self.mark += '_pr{}'.format(hub.pruning_iterations)
    hub.mark = self.mark

    # Each model has an agent to deal with some tensorflow stuff
    self.agent = Agent(self)

    # Define slots
    self._outputs = TensorSlot(self)

    self._metrics_manager = MetricsManager(self)

    self._validation_summary = SummarySlot(self)
    self._batch_val_summ = IndependentSummarySlot(self, 'batch_metric_summ')

    self._loss = TensorSlot(self, 'Loss')
    self._train_step = OperationSlot(self)
    self._train_step_summary = SummarySlot(self)

    self.validate_group = Group(
      self, self._validation_summary, name='Validate-group')

    self._update_group = Group(
      self, self._loss, self._train_step, self._train_step_summary,
      name='Update-group')

    # Slots for exporting np values to note
    self.grads_slot = NestedTensorSlot(self, 'Gradients')
    self.general_tensor_slot = NestedTensorSlot(self, 'General-Tensor')

    # Private attributes
    self._default_net = None  # TODO: to be removed
    self._optimizer = None
    self._built = False
    self._scheme = None

    # Public attributes
    self.counter = None
    self.rounds = None
    self.launched = False

    # Quantities
    self.loss_quantity = None

  # region : Properties

  # region : Accessor

  @property
  def affix(self): return 'model'

  @property
  def graph(self): return self.agent.graph

  @property
  def session(self): return self.agent.session

  @property
  def metrics_manager(self): return self._metrics_manager

  @property
  def key_metric(self):
    if not self.metrics_manager.has_metric: return None
    return self.metrics_manager.early_stop_slot

  @property
  def eval_metric(self):
    if not self.metrics_manager.has_metric: return None
    return self.metrics_manager.eval_slot

  @property
  def outputs(self):
    assert isinstance(self._outputs, TensorSlot)
    return self._outputs

  @property
  def loss(self):
    assert isinstance(self._loss, TensorSlot)
    return self._loss

  @property
  def train_step(self):
    assert isinstance(self._train_step, OperationSlot)
    return self._train_step

  @property
  def built(self):
    assert isinstance(self._built, bool)
    return self._built

  @property
  def record(self):
    if not self.key_metric.activated: return None
    else: return self.key_metric.record

  @property
  def variable_to_save(self):
    """Should be called in with_graph decorator"""
    vars = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
            tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
    # Remove `do not save` vars
    vars = [var for var in vars
            if var not in tf.get_collection(pedia.do_not_save)]
    filter_by_name = lambda key: [var for var in vars if key not in var.name]
    # Remove `train_opt` vars if necessary
    if not hub.save_train_opt_vars:
      vars = filter_by_name(pedia.train_opt)
      vars = filter_by_name('Optimizer')
    # Remove `dynamic_opt` vars
    vars = filter_by_name(pedia.dynamic_opt)
    # Krause optimizer related vars (TODO: need to be refactored)
    vars = filter_by_name('de_theta0')
    if not hub.train_stats_exists: vars = filter_by_name('de_sqrt_MS_g')
    return vars

  @property
  def metric_foreach(self):
    metrics = tf.get_collection(pedia.metric_foreach)
    assert len(metrics) == 1
    return metrics[0]

  @property
  def parameters_dict(self):
    # Fetch all trainable variables
    trainable_variables = tf.trainable_variables()
    values = self.session.run(trainable_variables)
    # Wrap them into a dictionary and return
    parameters = {}
    for t, v in zip(trainable_variables, values): parameters[t.name] = v
    return parameters

  # endregion : Accessor

  # region : Properties to be overrode

  @property
  def description(self): return 'No description'

  @property
  def input_type(self): return InputTypes.BATCH

  # endregion : Properties to be overrode

  # endregion : Properties

  # region : Building

  @with_graph
  def build(self, **kwargs):
    # Smooth out flags before important actions
    hub.smooth_out_conflicts()

    # Initialize pruner if necessary
    if any([hub.prune_on, hub.weights_mask_on, hub.etch_on]):
      # import here to prevent circular import (temporarily)
      from tframe.operators.prune.pruner import Pruner
      tfr.context.pruner = Pruner(self)

    # If optimizer is not provided here, try hub.get_optimizer(), which
    # requires that th.optimizer and th.learning_rate have been provided
    if 'optimizer' not in kwargs: kwargs['optimizer'] = hub.get_optimizer()

    # Call successor's _build method
    self._build(**kwargs)
    # Initialize monitor
    self._init_monitor()
    # Set built flag
    self._built = True
    # Show build info
    console.show_status('Model built successfully:')
    self.agent.take_notes('Model built successfully')
    self.agent.take_notes('Structure:', date_time=False)
    # Description may be a model structure
    description = self.description
    if not isinstance(description, (tuple, list)): description = [description]
    for line in description:
      assert isinstance(line, str)
      console.supplement(line)
      self.agent.take_notes(line, date_time=False)

    # Add metric slots to update group
    batch_metric = kwargs.get('batch_metric', [])
    if batch_metric:
      if not isinstance(batch_metric, (tuple, list)):
        batch_metric = [batch_metric]
      for metric_str in batch_metric:
        assert isinstance(metric_str, str)
        metric_slot = self.metrics_manager.get_slot_by_name(metric_str)
        self._update_group.add(metric_slot)

    # Register eval_metric if provided
    eval_metric = kwargs.get('eval_metric', None)
    if eval_metric is not None:
      assert isinstance(eval_metric, str)
      self.metrics_manager.register_eval_slot(eval_metric)

  def _build(self, optimizer=None, **kwargs):
    """Abstract method, must be implemented in different models.
       Usually touches the tensorflow API directly and plugs tf ops into
       tframe slots.
    """
    raise NotImplementedError('!! build method not implemented')

  def _init_monitor(self):
    pass
    # TODO
    # if tfr.monitor.activated: tfr.monitor.init_monitor(self)

  @with_graph
  def _define_train_step(self, optimizer=None, var_list=None):
    """TODO: should be modified for tframe.optimizer
       self._train_step will be plugged only here
    """
    if not self._loss.activated:
      raise AssertionError('!! loss has not been activated yet')
    with tf.name_scope('Optimizer'):
      if optimizer is None:
        optimizer = hub.get_optimizer()
        console.show_status(
          'Optimizer defined in trainer hub initialized.', '++')

      # TODO: BETA
      if hub.use_rtrl:
        raise AssertionError('use_rtrl option has been deprecated')
        from tframe.optimizers.rtrl_opt import RealTimeOptimizer
        optimizer = RealTimeOptimizer(self, optimizer)

      self._optimizer = optimizer
      self.set_train_step(var_list)

  def set_train_step(self, var_list=None):
    self._train_step.plug(
      self._optimizer.minimize(self._loss.op, var_list=var_list))

  def reset_optimizer(self):
    from tframe.optimizers.clip_opt import GradientClipOptimizer
    assert isinstance(self._optimizer, GradientClipOptimizer)
    self.session.run(self._optimizer.reset_tf_optimizer)
    console.show_status('TensorFlow optimizer has been reset.')

  def _merge_summaries(self):
    train_step_summaries = tf.get_collection(pedia.train_step_summaries)
    validation_summaries = tf.get_collection(pedia.validation_summaries)
    if len(train_step_summaries) > 0:
      self._train_step_summary.plug(tf.summary.merge(train_step_summaries))
    if len(validation_summaries) > 0:
      self._validation_summary.plug(tf.summary.merge(validation_summaries))

  # endregion : Building

  # region : Training

  def pretrain(self, **kwargs):
    """Method run early in the training process, should be overridden"""
    if self._scheme is not None:
      assert isinstance(self._scheme, TrainScheme)
      trial = self._scheme.dequeue()
      if trial is not None: trial.initialize(self)

  @with_graph
  def train(self, training_set, trainer_hub=None, validation_set=None,
            snapshot=None, probe=None, evaluate=None, terminator=None,
            test_set=None, **kwargs):
    if trainer_hub is None:
      trainer_class = SmartTrainer if hub.smart_train else Trainer
    else:
      if not isinstance(trainer_hub, TrainerHub):
        raise TypeError('!! Input hub must be an instance of TrainerHub')
      trainer_class = trainer_hub.trainer_class
    trainer = trainer_class(
      self, training_set=training_set, validation_set=validation_set,
      snapshot=snapshot, probe=probe, evaluate=evaluate,
      terminator=terminator, test_set=test_set)
    trainer.train(hub=trainer_hub, **kwargs)

  def update_model(self, data_batch, **kwargs):
    """Default model updating method, should be overridden"""
    feed_dict = self._get_default_feed_dict(data_batch, is_training=True)
    return self._update_group.run(feed_dict)

  def get_data_batches(self, data_set, batch_size, num_steps=None,
                       shuffle=False, is_training=False):
    """ Get a batch generator. This method is used both in training and in
        evaluation/validation.

        It is trivial for FNN models. However, for RNN models, data_set may be
        (1) a SequenceSet whose feature is a list of numpy arrays, each
            representing a sequence whose length may vary, e.g.
            data_set.feature = [
              xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx,
              xxxxxxxxxxxxxxxxxxx,
              xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx,
              xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx,
            ],
            in which x represents a data point.
            In this case, batch size can be
            (a) 1 (by default)
            (b) larger than 1, in which case the active_len mechanism will be
                used and val_num_steps will be forced to -1
        (2) a DataSet consisting of a single sequence, shape = [seq_len, *dim]
            In this case, batch size can be any integer

    :param data_set: an instance of DataSet or BigData from which data batches
                     will be extracted
    :param batch_size: if is None, a default value will be assigned according
                       to the input type of this model
    :param num_steps: step number for RNN data batches
    :param shuffle: whether to shuffle
    :return: a generator or a list
    """
    # Data set must be an instance of DataSet or BigData
    assert isinstance(data_set, (DataSet, BigData, PerpetualMachine))

    if self.input_type is InputTypes.BATCH:
      # 1. For FNN, `num_steps` will be ignored, default batch_size is -1 (all)
      # If batch size is not specified and data is a DataSet, feed it all at
      #  once into the model
      if batch_size is None and isinstance(data_set, DataSet):
        return [data_set.stack]
      # Otherwise batch_size must be a positive integer
      checker.check_positive_integer(batch_size)
      data_batches = data_set.gen_batches(
        batch_size, shuffle=shuffle, is_training=is_training)

    elif self.input_type is InputTypes.RNN_BATCH:
      # 2. For RNN, default batch_size is 1, default num_steps is -1 (all)
      # if num_steps is None: num_steps = -1
      if batch_size is None: batch_size = 1
      if batch_size < 0: batch_size = data_set.size
      # Cases:
      # (1) data_set is a DataSet but not a SequenceSet
      #     each data entry in data_dict will be considered as a consecutive
      #     sequence. batch_size and num_steps can be any integer
      # (2) data_set is a SequenceSet
      # ---------------+------------------------+--------------------------
      #                |     num_steps = -1     |     num_steps != -1
      # ---------------+------------------------+--------------------------
      # batch_size = 1 |     legal for all      |     legal for all *
      # ---------------+------------------------+--------------------------
      #                |  train: legal for equal-length sequences since
      #                |         act_len logic has not been implemented
      # batch_size > 1 |         for training *
      #                +------------------------+--------------------------
      #                |  val: legal for all    |   TODO: not supported
      # ---------------+------------------------+--------------------------
      #                                           * n_to_one must be False
      # Check batch_size
      # (a common DataSet may legally have num_steps > 0 while batch_size > 1)
      checker.check_positive_integer(batch_size)
      if batch_size > 1 and isinstance(data_set, SequenceSet):
        # assert num_steps < 0  # XXXXXXXX
        # The constraint below is not necessary due to the gather_indices
        # mechanism
        # if is_training and not hub.use_gather_indices:
        #   assert data_set.equal_length
        pass
      # Check num_steps
      checker.check_type(num_steps, int)
      if num_steps != -1:
        # partition logic for n_to_one tasks has not been implemented yet
        assert not data_set.n_to_one
      # Generate batches
      data_batches = data_set.gen_rnn_batches(
        batch_size, num_steps, shuffle, is_training=is_training)

    else:
      raise ValueError('!! Can not resolve input type of this model')

    return data_batches

  def validate_model(self, data_set, batch_size=None, allow_sum=False,
                     verbose=False, seq_detail=False, num_steps=None):
    """Evaluate quantities in the validate group of this model

    :param data_set: a tframe DataSet
    :param batch_size: if is None or -1, batch_size will be data_set.size
    :param allow_sum: whether to add tensorflow summaries
                      TODO: to be deprecated
    :return: a dictionary in which keys are slots (may include loss and metric)
             and values are scalars corresponding to these slots
    """
    assert isinstance(data_set, DataSet)
    if num_steps is None: num_steps = hub.val_num_steps

    # - One-shot validation
    one_shot = False
    batch_is_all = batch_size in (-1, None) or batch_size == data_set.size
    # .. check one-shot qualification
    # .. .. the code below should be encapsulated
    if self.input_type is InputTypes.BATCH:
      # (1) e.g. small model on MNIST, CIFAR-10
      if batch_is_all: one_shot = True
    elif self.input_type is InputTypes.RNN_BATCH:
      # (2)
      if isinstance(data_set, SequenceSet):
        # (2-a)
        if batch_is_all and num_steps == -1 and data_set.equal_length:
          # e.g. AP, TO
          one_shot = True
      else:
        # (2-b)
        assert isinstance(data_set, DataSet)
        # assert batch_size in (1, -1, None)  # TODO
        # e.g. small model on WHB
        if num_steps == -1: one_shot = True

    # .. do one-shot validation if qualified
    # .. for RNN models, the reset_batch flag of data_set should be set
    if one_shot:
      data_set = self._sanity_check_before_use(data_set)
      if self.input_type is InputTypes.RNN_BATCH:
        data_set.should_reset_state = True
      feed_dict = self._get_default_feed_dict(data_set, is_training=False)
      return self.validate_group.run(feed_dict, allow_sum=allow_sum)

    # - Otherwise do batch validation
    tensor_slots = self.validate_group.tensor_slots
    quantity_defs = [s.quantity_definition for s in tensor_slots]
    fetches = [q.quantities for q in quantity_defs]
    values = self.evaluate(
      fetches, data_set, batch_size, verbose=verbose, num_steps=num_steps)
    result_dict = OrderedDict()
    for val, qd, slot in zip(values, quantity_defs, tensor_slots):
      # Sanity check
      assert isinstance(qd, Quantity)
      if self.input_type is InputTypes.BATCH:
        assert isinstance(val, np.ndarray) and len(val) > 0
      else:
        assert isinstance(val, list)
        if not data_set.n_to_one: checker.check_type(val, np.ndarray)
      # Apply np_summ_method on val
      scalar = qd.apply_np_summ_method(val, seq_detail)
      # Add summ to results
      result_dict[slot] = scalar

    return result_dict

  def take_down_metric(self, is_online):
    for metric in self.metrics_manager.metrics:
      assert isinstance(metric, MetricSlot)
      if not metric.activated: continue
      notes = 'Best {}: {:.3f}'.format(metric.symbol, metric.record)
      # if not is_online:
      #   notes += ', Best {} = {:.3f}'.format(metric.symbol, metric.mean_record)
      self.agent.take_notes(notes, date_time=False)

      # Add history into notes if necessary
      if hub.show_record_history_in_note:
        self.agent.take_notes(metric.metric_mean_history_str, date_time=False)
      # Add record and mean record to notes
      self.agent.put_down_criterion(
        'Best {}'.format(metric.symbol), metric.record)
      # if not is_online:
      #   self.agent.put_down_criterion('Best E({})', metric.mean_record)

  def end_round(self, rnd):
    self.key_metric.end_round(rnd)

  def bust(self, rnd):
    if self._scheme is not None:
      assert isinstance(self._scheme, TrainScheme)
      trial = self._scheme.dequeue()
      if trial is not None:
        trial.initialize(self)
        return False
      else:
        return True
    return True

  # endregion : Training

  # region : Public Methods

  def handle_structure_detail(self):
    detail, total_params, dense_total = '', 0, 0
    if hasattr(self, 'structure_detail'):
      detail, total_params, dense_total = self.structure_detail

    # Maybe take some notes
    params_str = 'Total params: {}'.format(total_params)
    hub.total_params = int(total_params)
    if hub.prune_on:
      hub.dense_total_params = dense_total
      hub.weights_fraction = 100.0 * total_params / dense_total
      params_str += ' ({:.2f}%)'.format(hub.weights_fraction)
    self.agent.take_notes(params_str)

    if hub.show_structure_detail:
      print('.. Structure detail:\n{}'.format(detail))

  def get_trainable_variables(self, f=None):
    if f is None: f = lambda _: True
    variables = [v for v in tf.trainable_variables() if f(v)]
    values = self.session.run(variables)
    variable_dict = OrderedDict()
    for t, v in zip(variables, values): variable_dict[t.name] = v
    return variable_dict

  def tune_lr(self, new_lr=None, coef=1.0):
    # TODO
    if self._optimizer is None:
      raise ValueError('!! Optimizer not defined yet')
    if self._optimizer.__class__ in [tf.train.AdamOptimizer]:
      lr_name = '_lr'
    elif self._optimizer.__class__ in [tf.train.GradientDescentOptimizer]:
      lr_name = '_learning_rate'
    else:
      raise TypeError('!! Unsupported optimizer for lr tuning')

    old_lr = self._optimizer.__getattribute__(lr_name)
    if new_lr is None: new_lr = old_lr * coef
    self._optimizer.__setattr__(lr_name, new_lr)

    # Show status
    console.show_status(
      'Learning rate updated: {:.2e} => {:.2e}'.format(old_lr, new_lr))

    return new_lr

  def set_scheme(self, scheme):
    if not isinstance(scheme, TrainScheme):
      raise TypeError('!! scheme must be an instance of TrainScheme')
    self._scheme = scheme

  def shutdown(self):
    self.agent.shutdown()

  def launch_model(self, overwrite=False):
    return self.agent.launch_model(overwrite)

  def evaluate(self, fetches, data, batch_size=None, postprocessor=None,
               verbose=False, num_steps=None, suppress_n_to_one=False):
    """Evaluate tensors based on data
    TODO: note that if num_steps != -1, outputs from the same sequence may be
          partitioned. e.g., if single_fetch, outputs will be
          [array_1_1, ..., array_1_k1, array_2_1, ..., array_2_k2, ...]
          |-------- input_1 ----------|------------ input_2 ----------|
          This is OK for seq2seq validation, but needs to be post-processed
          in tasks like sequence classification (currently forbidden)

    :param fetches: a (tuple/list of) tf.Tensor(s) to be evaluated
    :param data: data used for evaluation
    :param batch_size: if not specified (None by default), batch_size will be
                       assigned accordingly. If assigned with a positive
                       integer, evaluation will be performed batch by batch.
    :param postprocessor: post-processor for outputs
    :return: commonly a (list of) numpy array(s), each of which has the same
             batch size as the provided data
    """
    # Sanity check for fetches
    checker.check_fetchable(fetches)
    single_fetch = not isinstance(fetches, (tuple, list))
    # Wrap fetches into a list if necessary
    if single_fetch: fetches = [fetches]
    if num_steps is None: num_steps = hub.val_num_steps
    if batch_size is None: batch_size = data.size

    # Get outputs (sometimes fetches may contain operations which yield None)
    outputs = [[] for op in fetches if not isinstance(op, tf.Operation)]

    if verbose:
      bar = ProgressBar(data.get_round_length(batch_size, num_steps))
      console.show_status('Evaluating on {} ...'.format(data.name))

    for cursor, data_batch in enumerate(self.get_data_batches(
        data, batch_size, num_steps)):
      data_batch = self._sanity_check_before_use(data_batch)
      # Get batch outputs            fetches[0]  fetches[1]
      # for FNN, batch_outputs = [np_array_1, np_array_2, ...]
      #          each np_array_k has the same batch_size
      # for RNN, batch_outputs = [[s1_1, s1_2, ..., s1_N],      <= fetches[0]
      #                           [s2_1, s2_2, ..., s2_N], ...] <= fetches[1]
      #          N is the batch_size, and each sk_i is a numpy array
      batch_outputs = self._evaluate_batch(
        fetches, data_batch, num_steps=num_steps,
        suppress_n_to_one=suppress_n_to_one)
      assert isinstance(batch_outputs, list)
      assert len(batch_outputs) == len(outputs)

      # Add batch_outputs to outputs accordingly
      for i, batch_output in enumerate(batch_outputs):
        assert isinstance(outputs[i], list)
        output_is_a_batch = fetches[i].shape.as_list()[0] is None
        if self.input_type is InputTypes.RNN_BATCH and output_is_a_batch:
          # batch_output is [s1_1, s1_2, ..., s1_N]
          assert isinstance(batch_output, list)
          outputs[i] = outputs[i] + batch_output
        else:
          # batch_output is a numpy array of length batch_size
          outputs[i].append(batch_output)

      # Show progress bar if necessary
      if verbose: bar.show(cursor + 1)

    # Merge outputs if necessary
    if self.input_type is InputTypes.BATCH:
      outputs = [np.concatenate(array_list, axis=0) for array_list in outputs]

    # Post-process and return
    if postprocessor is not None:
      assert callable(postprocessor)
      outputs = postprocessor(outputs)
      assert isinstance(outputs, list)
    if single_fetch: outputs = outputs[0]
    return outputs

  # endregion : Public Methods

  # region : Private Methods

  def _evaluate_batch(self, fetch_list, data_set, **kwargs):
    raise NotImplementedError

  @with_graph
  def _get_default_feed_dict(self, batch, is_training):
    feed_dict = {}
    for tensor in tf.get_collection(pedia.default_feed_dict):
      if 'input' in tensor.name.lower():
        feed_dict[tensor] = batch[pedia.features]
      elif tensor.name.lower() in ('target', 'targets'):
        # elif 'target' in tensor.name:
        # TODO: when predicting without outputting loss ...
        if batch.targets is not None: feed_dict[tensor] = batch.targets
      elif pedia.gather_indices in tensor.name:
        # TODO: when batch.size is 1, gather_indices is not necessary.
        #       However, Quantity will never know the exact batch size
        feed_dict[tensor] = batch.gather_indices
      else:
        name = tensor.name.split('/')[-1].split(':')[0]
        val = batch.data_dict.get(name, None)
        if val is not None: feed_dict[tensor] = val

    feed_dict.update(self.agent.get_status_feed_dict(is_training))

    return feed_dict

  def _sanity_check_before_use(self, data):
    # Make sure data is legal
    if not isinstance(data, DataSet):
      raise TypeError('!! Input data must be an instance of DataSet')
    # Make sure model has been built
    if not self.built: raise ValueError('!! Model not built yet')
    # Make sure model has been launched
    if not self.launched: self.launch_model(overwrite=False)
    # Make sure data type matches model input type
    if self.input_type is InputTypes.RNN_BATCH: data = data.as_rnn_batch
    else: assert not data.is_rnn_input
    return data
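
# Illustrative usage sketch, not part of the original source. `MyModel` stands
# for a hypothetical concrete subclass that implements `_build`, and
# `train_set` / `val_set` stand for tframe DataSet instances; only methods
# defined in the Model class above are used.
model = MyModel(mark='demo')
model.build()                                   # plugs slots, sets the built flag
model.train(train_set, validation_set=val_set)
scalars = model.validate_model(val_set)         # {slot: scalar} for validate_group
outputs = model.evaluate(model.outputs.tensor, val_set, batch_size=128)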
class Model(object):
  """
  Base class of [all?] kinds of models built on TensorFlow
  """
  model_name = 'default'

  def __init__(self, mark=None):
    # Model mark usually helps to decide the folder name
    self.mark = hub.mark or mark
    assert mark is not None

    # Each model has an agent to deal with some tensorflow stuff
    self.agent = Agent(self)

    # Define slots
    self._outputs = TensorSlot(self)

    self._metric = Metric(self, 'metric')
    self._validation_summary = SummarySlot(self)
    self._batch_val_summ = IndependentSummarySlot(self, 'batch_metric_summ')
    self._validate_group = Group(
      self, self._metric, self._validation_summary, name='Validate-group')

    self._loss = TensorSlot(self, 'Loss')
    self._train_step = OperationSlot(self)
    self._train_step_summary = SummarySlot(self)
    self._update_group = Group(
      self, self._loss, self._metric, self._train_step,
      self._train_step_summary, name='Update-group')

    # Private attributes
    self._default_net = None
    self._optimizer = None
    self._built = False
    self._scheme = None

    # Public attributes
    self.counter = None
    self.launched = False

  # region : Properties

  # region : Accessor

  @property
  def graph(self): return self.agent.graph

  @property
  def session(self): return self.agent.session

  @property
  def metric(self):
    if self._metric is not None:
      assert isinstance(self._metric, Metric)
    return self._metric

  @property
  def outputs(self):
    assert isinstance(self._outputs, TensorSlot)
    return self._outputs

  @property
  def loss(self):
    assert isinstance(self._loss, TensorSlot)
    return self._loss

  @property
  def train_step(self):
    assert isinstance(self._train_step, OperationSlot)
    return self._train_step

  @property
  def built(self):
    assert isinstance(self._built, bool)
    return self._built

  @property
  def record(self):
    if not self.metric.activated: return None
    else: return self.metric.record

  @property
  def variable_to_save(self):
    """Should be called in with_graph decorator"""
    vars = (tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) +
            tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS))
    return [var for var in vars
            if var not in tf.get_collection(pedia.do_not_save)]

  # endregion : Accessor

  # region : Properties to be overrode

  @property
  def description(self): return 'No description'

  @property
  def input_type(self): return InputTypes.BATCH

  # endregion : Properties to be overrode

  # endregion : Properties

  # region : Building

  @with_graph
  def build(self, optimizer=None, **kwargs):
    # Smooth out flags before important actions
    hub.smooth_out_conflicts()
    #
    self._build(optimizer=optimizer, **kwargs)
    # Initialize monitor
    self._init_monitor()
    # Set built flag
    self._built = True
    # Show build info
    console.show_status('Model built successfully:')
    description = self.description
    if not isinstance(description, (tuple, list)): description = [description]
    for line in description:
      assert isinstance(line, str)
      console.supplement(line)
    # Maybe take some notes
    self.agent.take_notes('Model built successfully')
    self.agent.take_notes('Structure:', date_time=False)
    for line in description:
      self.agent.take_notes(line, date_time=False)

  def _build(self, optimizer=None, **kwargs):
    """Abstract method, must be implemented in different models"""
    raise NotImplementedError('!! build method not implemented')

  def _init_monitor(self):
    if tfr.monitor.activated: tfr.monitor.init_monitor(self)

  @with_graph
  def _define_train_step(self, optimizer=None, var_list=None):
    if not self._loss.activated:
      raise AssertionError('!! loss has not been activated yet')
    with tf.name_scope('Optimizer'):
      if optimizer is None: optimizer = tf.train.AdamOptimizer(1e-4)
      self._optimizer = optimizer
      self._train_step.plug(
        optimizer.minimize(self._loss.op, var_list=var_list))

  def _merge_summaries(self):
    train_step_summaries = tf.get_collection(pedia.train_step_summaries)
    validation_summaries = tf.get_collection(pedia.validation_summaries)
    if len(train_step_summaries) > 0:
      self._train_step_summary.plug(tf.summary.merge(train_step_summaries))
    if len(validation_summaries) > 0:
      self._validation_summary.plug(tf.summary.merge(validation_summaries))

  # endregion : Building

  # region : Training

  def pretrain(self, **kwargs):
    """Method run early in the training process, should be overridden"""
    if self._scheme is not None:
      assert isinstance(self._scheme, TrainScheme)
      trial = self._scheme.dequeue()
      if trial is not None: trial.initialize(self)

  @with_graph
  def train(self, training_set, trainer_hub=None, validation_set=None,
            snapshot=None, probe=None, **kwargs):
    if trainer_hub is None:
      trainer_class = SmartTrainer if hub.smart_train else Trainer
    else:
      if not isinstance(trainer_hub, TrainerHub):
        raise TypeError('!! Input hub must be an instance of TrainerHub')
      trainer_class = trainer_hub.trainer_class
    trainer = trainer_class(
      self, training_set=training_set, validation_set=validation_set,
      snapshot=snapshot, probe=probe)
    trainer.train(hub=trainer_hub, **kwargs)

  def update_model(self, data_batch, **kwargs):
    """Default model updating method, should be overridden"""
    feed_dict = self._get_default_feed_dict(data_batch, is_training=True)
    return self._update_group.run(feed_dict)

  def get_data_batches(self, data_set, batch_size, num_steps=None,
                       shuffle=False):
    """ Get a batch generator.

    :param data_set: an instance of DataSet or BigData from which data batches
                     will be extracted
    :param batch_size: if is None, a default value will be assigned according
                       to the input type of this model
    :param num_steps: step number for RNN data batches
    :param shuffle: whether to shuffle
    :return: a generator or a list
    """
    # Data set must be an instance of DataSet or BigData
    assert isinstance(data_set, (DataSet, BigData))

    if self.input_type is InputTypes.BATCH:
      # If model's input type is normal batch, num_steps will be ignored
      # If batch size is not specified and data is a DataSet, feed it all at
      #  once into the model
      if batch_size is None and isinstance(data_set, DataSet):
        return [data_set.stack]
      checker.check_positive_integer(batch_size)
      data_batches = data_set.gen_batches(batch_size, shuffle=shuffle)
    elif self.input_type is InputTypes.RNN_BATCH:
      if batch_size is None: batch_size = 1
      if num_steps is None: num_steps = -1
      checker.check_positive_integer(batch_size)
      checker.check_type(num_steps, int)
      data_batches = data_set.gen_rnn_batches(batch_size, num_steps, shuffle)
    else:
      raise ValueError('!! Can not resolve input type of this model')

    return data_batches

  def validate_model(self, data, batch_size=None, allow_sum=False):
    """Validate model. If the data provided is not a regular array, batch
       validation will be used. For RNN models, batch validation requires
       the batch size to be 1."""
    assert isinstance(data, TFRData)
    if not data.is_regular_array and batch_size is None: batch_size = 1

    # Normal validation
    if batch_size is None:
      data = self._sanity_check_before_use(data)
      feed_dict = self._get_default_feed_dict(data, is_training=False)
      return self._validate_group.run(feed_dict, allow_sum=allow_sum)

    # Batch validation: calculate metric batch by batch
    metric_list = []
    total = 0
    for batch in self.get_data_batches(data, batch_size, -1, False):
      # Calculate weight
      weight = batch.targets.shape[0]
      if self.input_type is InputTypes.RNN_BATCH:
        weight *= batch.targets.shape[1]
      assert weight > 0
      total += weight
      # Validate batch
      batch = self._sanity_check_before_use(batch)
      feed_dict = self._get_default_feed_dict(batch, is_training=False)
      metric_list.append(self._metric.run(feed_dict) * weight)

    # Return metric mean
    metric_mean = np.sum(metric_list) / total
    if allow_sum: self._batch_val_summ.write(metric_mean)
    return {self._metric: metric_mean}

  def take_down_metric(self):
    if not self.metric.activated: return
    notes = 'Record: {:.3f}, Mean Record: {:.3f}'.format(
      self.metric.record, self.metric.mean_record)
    self.agent.take_notes(notes, date_time=False)

  # TODO
  # def begin_round(self, **kwargs):
  #   pass

  def end_round(self, rnd):
    self.metric.end_round(rnd)

  def bust(self, rnd):
    if self._scheme is not None:
      assert isinstance(self._scheme, TrainScheme)
      trial = self._scheme.dequeue()
      if trial is not None:
        trial.initialize(self)
        return False
      else:
        return True
    return True

  # endregion : Training

  # region : Public Methods

  def tune_lr(self, new_lr=None, coef=1.0):
    # TODO
    if self._optimizer is None:
      raise ValueError('!! Optimizer not defined yet')
    if self._optimizer.__class__ in [tf.train.AdamOptimizer]:
      lr_name = '_lr'
    elif self._optimizer.__class__ in [tf.train.GradientDescentOptimizer]:
      lr_name = '_learning_rate'
    else:
      raise TypeError('!! Unsupported optimizer for lr tuning')

    old_lr = self._optimizer.__getattribute__(lr_name)
    if new_lr is None: new_lr = old_lr * coef
    self._optimizer.__setattr__(lr_name, new_lr)

    # Show status
    console.show_status(
      'Learning rate updated: {:.2e} => {:.2e}'.format(old_lr, new_lr))

    return new_lr

  def set_scheme(self, scheme):
    if not isinstance(scheme, TrainScheme):
      raise TypeError('!! scheme must be an instance of TrainScheme')
    self._scheme = scheme

  def shutdown(self):
    self.agent.shutdown()

  def launch_model(self, overwrite=False):
    return self.agent.launch_model(overwrite)

  # endregion : Public Methods

  # region : Private Methods

  @with_graph
  def _get_default_feed_dict(self, batch, is_training):
    feed_dict = {}
    for tensor in tf.get_collection(pedia.default_feed_dict):
      if 'input' in tensor.name.lower():
        feed_dict[tensor] = batch[pedia.features]
      elif 'target' in tensor.name:
        # TODO: when predicting without outputting loss ...
        if batch.targets is not None: feed_dict[tensor] = batch.targets
      else:
        name = tensor.name.split('/')[-1].split(':')[0]
        val = batch.data_dict.get(name, None)
        if val is not None: feed_dict[tensor] = val

    feed_dict.update(self.agent.get_status_feed_dict(is_training))

    return feed_dict

  def _sanity_check_before_use(self, data):
    if not isinstance(data, DataSet):
      raise TypeError('!! Input data must be an instance of DataSet')
    if not self.built: raise ValueError('!! Model not built yet')
    if not self.launched: self.launch_model(overwrite=False)
    if self.input_type is InputTypes.RNN_BATCH: data = data.as_rnn_data
    else: assert not data.in_rnn_format
    return data