def record_stats_on_dataset(
        self, data_set, slot_scalar_dict, take_down_on_slot=False, rnd=None):
    """Record the latest scalars of a data set, one Statistic per slot.

    Statistics are kept in `self.stats_dict[data_set]`, an OrderedDict that
    maps each metric slot to a `Statistic(max_length=2)`.

    :param data_set: key under which the statistics are stored
                     (a tframe DataSet according to the original notes)
    :param slot_scalar_dict: a dictionary mapping MetricSlot to scalar, as
                             returned by model.validate_model
    :param take_down_on_slot: whether to also take the scalar down on the
                              metric slot itself, usually set to True if
                              data_set is val_set
    :param rnd: current round, must be provided if take_down_on_slot
    :return: True if the early-stop slot set a new record during this call,
             False otherwise
    """
    # Sanity check
    # assert isinstance(data_set, DataSet)
    assert isinstance(slot_scalar_dict, dict)

    # Create the per-dataset container on first use
    if data_set not in self.stats_dict:
        self.stats_dict[data_set] = OrderedDict()
    slot_stats = self.stats_dict[data_set]
    early_stop_record = False
    assert isinstance(slot_stats, OrderedDict)

    for slot, scalar in slot_scalar_dict.items():
        assert isinstance(slot, MetricSlot)
        # Create the Statistic of this slot on first use
        if slot not in slot_stats:
            slot_stats[slot] = Statistic(max_length=2)
        statistic = slot_stats[slot]
        assert isinstance(statistic, Statistic)
        statistic.record(scalar)

        # Nothing more to do unless the slot itself should be updated
        if not take_down_on_slot:
            continue

        assert rnd is not None
        is_new_record = slot.take_down(
            scalar, rnd, self.model.counter, hub.record_gap)
        is_early_stop_slot = slot is self.early_stop_slot

        # Take note for later print
        note_key = (data_set, slot)
        if is_new_record:
            self.note[note_key] = '<New Record>'
            if is_early_stop_slot:
                early_stop_record = True
                if self.resurrected:
                    self._record_after_resurrection(scalar)
            continue

        # No new record: note the best value and how long it has been idle
        idle = self.idle_counter(slot, rnd)
        idle_info = (
            'Patience {}/{}'.format(idle, self.th.patience)
            if hub.early_stop and is_early_stop_slot
            else 'Idle: {}'.format(idle))
        self.note[note_key] = '(Best: {}, {})'.format(
            hub.decimal_str(slot.record, hub.val_decimals), idle_info)

    return early_stop_record
def __init__(
        self,
        model,
        training_set=None,
        validation_set=None,
        snapshot=None,
        probe=None,
        evaluate=None,
        terminator=None,
        test_set=None,
):
    """Bind a model and optional data sets / callbacks to this trainer.

    :param model: must be an instance of tfr.models.Model, otherwise a
                  TypeError is raised
    :param training_set: optional training set, forwarded to set_data
    :param validation_set: optional validation set, forwarded to set_data
    :param snapshot: optional callable, passed through
                     checker.check_callable and stored as
                     self._snapshot_function
    :param probe: optional callable, passed through checker.check_callable
    :param evaluate: optional callable, passed through
                     checker.check_callable
    :param terminator: optional callable (asserted when not None), stored
                       as self._terminator
    :param test_set: optional test set, forwarded to set_data
    """
    # Set model for trainer
    if not isinstance(model, tfr.models.Model):
        raise TypeError('!! model must be an instance of tframe Model')
    self.model = model
    # Give the model's metrics manager a back-reference to this trainer
    self.model.metrics_manager.trainer = self

    # Data set attributes
    self._training_set = None
    self._validation_set = None
    self._test_set = None
    self.set_data(training_set, validation_set, test_set)

    # Set callable attributes
    self._snapshot_function = checker.check_callable(snapshot)
    self._probe = checker.check_callable(probe)
    self._evaluate = checker.check_callable(evaluate)

    # Initiate trainer hub (the hub receives this partially initialized
    # trainer, so the statements above must run first)
    self.th = TrainerHub(self)

    # Private Attributes
    self._record_count = 0
    self._warm_up = True
    self.batch_loss_stat = Statistic(max_length=self.th.hist_buffer_len)

    # Instance-level override of the class attribute HubClass
    self.HubClass = TrainerHub
    if terminator is not None:
        assert callable(terminator)
    self._terminator = terminator

    # Important, since th.lives initialized by shell command will not change
    self._lives = self.th.lives

    # TODO
    # temporary solution to give agent the access to trainer
    context.trainer = self
class Trainer(object):
    """Base class of trainer for training tframe models.

    Model save mechanism when save_mode is
      (1) SaveMode.NAIVE:
          Model will be saved only at the end of each round naively
      (2) SaveMode.ON_RECORD:
          Model will be saved only when a new metric record appears
          after model finishes its warm-up rounds
    """
    # Hub class accepted by `train(hub=...)`; set to TrainerHub per instance
    # in __init__
    HubClass = None

    def __init__(
            self,
            model,
            training_set=None,
            validation_set=None,
            snapshot=None,
            probe=None,
            evaluate=None,
            terminator=None,
            test_set=None,
    ):
        """Bind a model and optional data sets / callbacks to this trainer.

        :param model: must be an instance of tfr.models.Model, otherwise a
                      TypeError is raised
        :param training_set: optional training set, forwarded to set_data
        :param validation_set: optional validation set, forwarded to set_data
        :param snapshot: optional callable, passed through
                         checker.check_callable; called with the model in
                         _snapshot
        :param probe: optional callable, passed through
                      checker.check_callable; called with this trainer in
                      _run_probe
        :param evaluate: optional callable, passed through
                         checker.check_callable; called with this trainer in
                         _end_training
        :param terminator: optional callable (asserted when not None); called
                           with the early-stop criterion in _validate_model
        :param test_set: optional test set, forwarded to set_data
        """
        # Set model for trainer
        if not isinstance(model, tfr.models.Model):
            raise TypeError('!! model must be an instance of tframe Model')
        self.model = model
        # Give the model's metrics manager a back-reference to this trainer
        self.model.metrics_manager.trainer = self

        # Data set attributes
        self._training_set = None
        self._validation_set = None
        self._test_set = None
        self.set_data(training_set, validation_set, test_set)

        # Set callable attributes
        self._snapshot_function = checker.check_callable(snapshot)
        self._probe = checker.check_callable(probe)
        self._evaluate = checker.check_callable(evaluate)

        # Initiate trainer hub (the hub receives this partially initialized
        # trainer, so the statements above must run first)
        self.th = TrainerHub(self)

        # Private Attributes
        # NOTE(review): _record_count is reset in _inner_loop but is not
        # incremented anywhere in the code visible here -- presumably it is
        # updated elsewhere; verify
        self._record_count = 0
        self._warm_up = True
        self.batch_loss_stat = Statistic(max_length=self.th.hist_buffer_len)

        # Instance-level override of the class attribute HubClass
        self.HubClass = TrainerHub
        if terminator is not None:
            assert callable(terminator)
        self._terminator = terminator

        # Important, since th.lives initialized by shell command will not
        # change
        self._lives = self.th.lives

        # TODO
        # temporary solution to give agent the access to trainer
        context.trainer = self

    # region : Properties

    @property
    def key_metric(self):
        """The early-stop slot of the metrics manager."""
        return self.metrics_manager.early_stop_slot

    @property
    def training_set(self):
        """The training set; asserted to be a TFRData when not None."""
        if self._training_set is not None:
            assert isinstance(self._training_set, TFRData)
        return self._training_set

    @property
    def validation_set(self):
        """The validation set; asserted to be a TFRData when not None."""
        if self._validation_set is not None:
            assert isinstance(self._validation_set, TFRData)
        return self._validation_set

    @property
    def test_set(self):
        """The test set; asserted to be a TFRData when not None."""
        if self._test_set is not None:
            assert isinstance(self._test_set, TFRData)
        return self._test_set

    @property
    def is_online(self):
        """True when the training set is a PerpetualMachine."""
        return isinstance(self.training_set, PerpetualMachine)

    @property
    def session(self):
        """The model's session, asserted to be a tf.Session."""
        session = self.model.session
        assert isinstance(session, tf.Session)
        return session

    @property
    def counter(self):
        """Iteration counter; stored on the model, not on the trainer."""
        return self.model.counter

    @counter.setter
    def counter(self, value):
        self.model.counter = value

    @property
    def metrics_manager(self):
        """The model's metrics manager, asserted to be a MetricsManager."""
        assert isinstance(self.model.metrics_manager, MetricsManager)
        return self.model.metrics_manager

    @property
    def total_rounds(self):
        """counter / th.round_length, or None if round_length is None."""
        # TODO: CC
        # TODO: Batch size must be kept the same among different trials
        if self.th.round_length is None:
            return None
        return self.counter / self.th.round_length

    @property
    def graph(self):
        """The model's graph."""
        return self.model.graph

    @property
    def _save_model_when_record_appears(self):
        """Whether a new record should trigger a save (ON_RECORD mode).

        Requires th.save_model, SaveMode.ON_RECORD and a finished warm-up;
        additionally blocked when th.at_most_save_once_per_round is set and
        _record_count exceeds 1.
        """
        return (self.th.save_model and
                self.th.save_mode is SaveMode.ON_RECORD and
                not self._warm_up and
                not (self.th.at_most_save_once_per_round and
                     self._record_count > 1))

    @property
    def _save_model_at_round_end(self):
        """Whether the model should be saved at round end (NAIVE mode)."""
        return self.th.save_model and self.th.save_mode is SaveMode.NAIVE

    @property
    def _save_model_at_training_end(self):
        """Whether the model should be saved once training is over."""
        return self.th.save_model and self.th.save_model_at_the_end

    # endregion : Properties

    # region : Public Methods

    def set_data(self, training_set=None, validation_set=None, test_set=None):
        """Check and store each data set that is not None.

        Data sets passed as None leave the corresponding attribute untouched.
        _check_data raises TypeError for anything that is not a TFRData.
        """
        if training_set is not None:
            self._check_data(training_set, 'training set')
            self._training_set = training_set
        if validation_set is not None:
            self._check_data(validation_set, 'validation set')
            self._validation_set = validation_set
        if test_set is not None:
            self._check_data(test_set, 'test set')
            self._test_set = test_set

    def recover_progress(self, start_time=None):
        """Re-print the progress bar if it is on and round length is known."""
        # Print progress bar
        if self.th.progress_bar and self.th.round_length is not None:
            assert isinstance(self._training_set, TFRData)
            progress = self.th.round_progress
            assert progress is not None
            console.print_progress(progress=progress, start_time=start_time)

    # endregion : Public Methods

    # region : Train

    @with_graph
    def train(self, hub=None, **kwargs):
        """Run the whole training procedure.

        :param hub: optional hub replacing self.th; must be an instance of
                    self.HubClass (see _init_trainer_hub)
        :param kwargs: forwarded to _init_trainer_hub and model.pretrain
        """
        # Set trainer hub
        self._init_trainer_hub(hub, **kwargs)
        # Run model's pre-train method
        self.model.pretrain(**kwargs)
        # Do some check-up (the three calls are evaluated left to right; the
        # resulting tuple is discarded)
        self._check_data(), self._sanity_check(), self.th.sanity_check()
        # Check model.session
        self._check_model()
        # Show configurations
        self._show_configurations()
        # Maybe take down some notes
        self._take_notes_before_loops()

        # Train with graph
        with self.session.as_default():
            rounds = self._outer_loop()

        # :: After training
        self._end_training(rounds)
        self._handle_notes()

        # Prune and save if necessary
        if self.th.prune_on:
            context.pruner.prune_and_save_lottery18()

    # region : Before training

    def _init_trainer_hub(self, hub, **kwargs):
        """Install the given hub, or set up the existing one with kwargs.

        :raises TypeError: if hub is given but is not a self.HubClass
        """
        if hub is not None:
            # If th is provided
            if not isinstance(hub, self.HubClass):
                raise TypeError('!! config must be an instance of {}'.format(
                    self.HubClass))
            self.th = hub
            self.th.trainer = self
        else:
            self.th.set_up(**kwargs)

        # Set progress bar (kept on only if round_len_is_active)
        if self.th.progress_bar:
            self.th.progress_bar = self.th.round_len_is_active

        # Other setting
        if not self.th.warm_up:
            self._warm_up = False

    def _sanity_check(self):
        """Should be overridden by subclasses"""
        pass

    def _show_configurations(self):
        """Print every config string and copy it into the agent's notes."""
        console.show_status('Configurations:')
        self.model.agent.take_notes('Configurations:', date_time=False)
        for config in self.th.config_strings:
            console.supplement(config)
            self.model.agent.take_notes(
                '.. {}'.format(config), date_time=False)

    def _take_notes_before_loops(self):
        """Hook before the loops; currently does nothing in either case."""
        if not self.th.export_note:
            return

    def _check_model(self):
        """Launch the model if necessary and initialize model.rounds."""
        if not self.model.launched:
            self.model.launch_model(self.th.overwrite)
        # Check model.epoch
        if not self.is_online and self.model.rounds is None:
            self.model.rounds = 0

    # endregion : Before training

    # region : During training

    def _outer_loop(self):
        """Run up to hub.total_outer_loops rounds, then evaluate / save.

        :return: the number of rounds actually started
        """
        hub = self.th
        rnd = 0
        for _ in range(hub.total_outer_loops):
            rnd += 1
            if self.is_online:
                console.section('Iterations Begin')
            else:
                console.section('{} {}'.format(hub.round_name, rnd))
            hub.tic()

            # Do inner loop
            self._inner_loop(rnd)
            # End of round
            if hub.progress_bar:
                console.show_status(
                    'End of {}. Elapsed time is {:.1f} secs'.format(
                        hub.round_name, hub.toc()))
            # Inc rounds for models training in epochs
            if self.model.rounds is not None:
                self.model.rounds += 1.0
            # Maybe give a report on metric
            if not self.is_online and hub.validation_on:
                self.model.end_round(rnd)
                if self.key_metric.get_idle_rounds(rnd) > self.th.patience:
                    self.th.raise_stop_flag()

            # Maybe save model (model.rounds var has been increased)
            if self._save_model_at_round_end:
                self._save_model()

            break_flag = False
            # Early stop via stop flag TODO: needed to be unified
            if hub.stop and self.model.bust(rnd):
                break_flag = True
            # Force terminate
            if hub.force_terminate:
                break_flag = True
            # Resurrect if possible (a remaining life cancels the break)
            if break_flag and self._lives > 0:
                self.resurrect(rnd)
                if not self.metrics_manager.resurrected:
                    self.metrics_manager.resurrected = True
                    self.metrics_manager.rar0 = (
                        self.metrics_manager.early_stop_criterion)
                hub.force_terminate = False
                break_flag = False
            # Break if needed to
            if break_flag:
                break

        # Out of loop
        if hub.gather_note:
            if self.is_online:
                self.model.agent.put_down_criterion(
                    'Total Iterations', self.counter)
            else:
                self.model.agent.put_down_criterion('Total Rounds', rnd)

        # Put down final weight fraction if etch is on
        if self.th.etch_on:
            frac = context.pruner.weights_fraction
            self.model.agent.take_notes(
                'Final weight fraction: {:.2f}%'.format(frac))
            self.model.agent.put_down_criterion('Weight Fraction', frac)

        # Evaluate the best model if necessary
        ds_dict = OrderedDict()
        if hub.evaluate_train_set:
            ds_dict['Train'] = self.training_set
        if hub.evaluate_val_set:
            ds_dict['Val'] = self.validation_set
        if hub.evaluate_test_set:
            ds_dict['Test'] = self.test_set
        if len(ds_dict) > 0:
            # Load the best model
            if hub.save_model:
                flag, _, _ = self.model.agent.load()
                assert flag
            # Evaluate the specified data sets
            for name, data_set in ds_dict.items():
                if not isinstance(data_set, TFRData):
                    raise TypeError('!! {} set is not a TFRData'.format(name))
                # TODO
                value = self.model.evaluate_model(
                    data_set, batch_size=hub.eval_batch_size)
                title = '{} {}'.format(
                    name, self.metrics_manager.eval_slot.name)
                self.model.agent.put_down_criterion(title, value)
                self.model.agent.take_notes('{}: {}'.format(
                    title, hub.decimal_str(value, hub.val_decimals)))

        # Save model here if necessary (asserted to be mutually exclusive
        # with the evaluation above)
        if self._save_model_at_training_end:
            assert len(ds_dict) == 0
            self._save_model()

        return rnd

    def _inner_loop(self, rnd):
        """Iterate over the batches of one round.

        Each iteration updates the model, prints progress, validates (and
        maybe saves), etches, probes and takes notes, in that order.
        """
        self._record_count = 0
        # Begin iteration
        self.th.cursor = 0
        for i, batch in enumerate(self._gen_batches()):
            # Sanity check (make sure sequence batch is equal-length)
            self._check_data_batch(batch)
            # Increase iteration counter
            self.th.cursor += 1
            self.counter += 1
            # Update model
            loss_dict = self._update_model(batch)
            # Print progress
            self._print_progress(rnd, loss_dict)
            # Validation
            if self._validate_model(
                    rnd) and self._save_model_when_record_appears:
                if not self.is_online:
                    assert np.isscalar(self.th.round_progress)
                self._save_model(
                    inter_cut=True, progress=self.th.round_progress)
            # Etch
            self._etch()
            # Probe
            self._run_probe()
            # Take notes
            self._take_notes_for_export()

            # Check early stop condition
            if self.is_online:
                if self.th.max_iterations is not None:
                    if i + 1 >= self.th.max_iterations:
                        self.th.force_terminate = True
                if self.th.early_stop:
                    if self.key_metric.get_idle_counts(
                            self.counter) > self.th.patience:
                        self.th.force_terminate = True
            # After probing, training process may be terminated
            if self.th.force_terminate:
                # If model will be resurrected later, dynamic_round_len of
                # train_set should be set to None. Otherwise error may occur
                # TODO
                if hasattr(self.training_set, '_clear_dynamic_round_len'):
                    # Perpetual Machine does not have this method
                    self.training_set._clear_dynamic_round_len()
                break

        # Check warm up logic (warm-up ends once _record_count of the round
        # is below th.warm_up_thres)
        if self._warm_up and self._record_count < self.th.warm_up_thres:
            self._warm_up = False

    def resurrect(self, rnd):
        """Spend one life: reload the saved model and maybe decay the lr.

        :param rnd: current round, written to the key metric's record round
        """
        # Decrease lives by 1 and show status
        assert self._lives > 0
        self._lives -= 1
        console.show_status(
            'Lives decreased to {}'.format(self._lives), '[Resurrect]')
        console.show_status('Resurrecting ...')
        # [Compromise] set record counter or round
        self.key_metric.set_record_counter(self.counter)
        self.key_metric.set_record_round(rnd)
        # Load model
        flag, _, _ = self.model.agent.load()
        assert flag
        # Decay learning rate if necessary
        if self.th.lr_decay < 1.0:
            assert self.th.lr_decay > 0
            self.th.clip_lr_multiplier *= self.th.lr_decay
            if self.th.reset_optimizer_after_resurrection:
                self.model.reset_optimizer()
            self.model.set_train_step()
            console.show_status('Learning rate decayed to {:.6f}'.format(
                self.th.learning_rate * self.th.clip_lr_multiplier))

    # endregion : During training

    # region : After training

    def _end_training(self, rounds):
        """Wrap up training: summaries, notes and the evaluate callback.

        :param rounds: number of rounds returned by _outer_loop
        """
        if self.th.progress_bar:
            console.clear_line()
        # If this is a hp-tuning task, write record summary
        if self.th.hp_tuning:
            assert not self.th.summary
            self.key_metric.write_record_summary()
        # Flush summary
        if self.th.summary or self.th.hp_tuning:
            self.model.agent.summary_writer.flush()
        # Take notes
        if self.is_online:
            self.model.agent.take_notes(
                'End training after {} iterations'.format(self.counter))
        else:
            total_round = ('' if self.total_rounds is None
                           else ' ({:.1f} total)'.format(self.total_rounds))
            self.model.agent.take_notes(
                'End training after {} rounds{}'.format(rounds, total_round))
        # Evaluate
        if self._evaluate is not None:
            # Load the best model if necessary
            if self.th.save_model:
                flag, _, _ = self.model.agent.load()
                assert flag
            # Evaluate model
            self._evaluate(self)
        # Show RAS if necessary
        if self.th.lives > 0:
            ras_info = self.metrics_manager.RAR_string
            console.show_status(ras_info)
            self.model.agent.take_notes(ras_info)

    def _handle_notes(self):
        """Finalize, show and optionally export / gather the agent's notes."""
        # Add metric info into notes
        if self.th.validation_on:
            self.model.take_down_metric(self.is_online)
        # Put down key configurations to note
        self.model.agent.put_down_configs(self.th)
        # Show notes
        self.model.agent.show_notes()
        # Export notes if necessary
        if self.th.export_note:
            self.model.agent.export_notes()
        # Gather notes if necessary
        if self.th.gather_note:
            self.model.agent.gather_notes()

    # endregion : After training

    # endregion : Train

    # region : Private Methods

    def _update_model(self, data_batch):
        """Run one model update and record the batch loss.

        Exactly one key of the returned dict must have the name 'Loss'.
        Gradient and general-tensor entries are popped from the dict and
        handed to context.monitor when the corresponding switches are on.

        :return: the (possibly reduced) dict returned by model.update_model
        """
        loss_dict = self.model.update_model(data_batch=data_batch)
        loss_slots = [s for s in loss_dict.keys() if s.name == 'Loss']
        # assert len(loss_slots) > 0
        assert len(loss_slots) == 1
        loss_slot = loss_slots[0]
        self.batch_loss_stat.record(loss_dict[loss_slot])

        # Record grads if necessary
        # <monitor_grad_step_03: fetch and record>
        if self.th.monitor_weight_grads:
            grads = loss_dict.pop(self.model.grads_slot)
            context.monitor.record_grads(grads)

        # Record other tensors
        if self.model.general_tensor_slot.activated:
            tensors = loss_dict.pop(self.model.general_tensor_slot)
            context.monitor.record_tensors(tensors)

        # Check NaN (the first NaN found requests a forced termination)
        if self.th.terminate_on_nan:
            for val in loss_dict.values():
                if np.isnan(val):
                    msg = ('Forced termination triggered due to NAN in '
                           'loss_dict')
                    console.show_status(msg)
                    self.model.agent.take_notes(msg)
                    self.th.force_terminate = True
                    break

        return loss_dict

    def _check_data(self, data_set=None, name='dataset'):
        """Make sure a data set exists and is a TFRData.

        Called without a data set, the training set is checked.

        :raises ValueError: if the data set to check is None
        :raises TypeError: if it is not an instance of TFRData
        """
        if data_set is None:
            data_set = self._training_set
            name = 'training set'
        if data_set is None:
            raise ValueError('!! {} not found'.format(name))
        if not isinstance(data_set, TFRData):
            raise TypeError(
                '!! {} must be an instance of TFRData'.format(name))

    @staticmethod
    def _check_callable(f, name=None):
        """Return f after checking that it is None or callable.

        NOTE(review): when name is None this returns None without looking at
        f at all -- confirm that this is intended.

        :raises TypeError: if name is given and f is neither None nor callable
        """
        if name is None:
            return
        if f is not None and not callable(f):
            raise TypeError('!! {} must be callable'.format(name))
        return f

    def _gen_batches(self):
        """This method will be called only in the inner loop of train
        process."""
        if isinstance(self.training_set, SequenceSet):
            # TODO: for now a batch consists of sequences with different
            # lengths can not be used for training for the padded 0s may
            # produce inappropriate gradients.
            # if (self.th.batch_size > 1 and
            #     not self.training_set.parallel_on and
            #     self.training_set.batch_preprocessor is None):
            #   # a batch of equal-length sequences is allowed
            #   raise AssertionError('!! parallel engine is not activated')
            pass
        return self.model.get_data_batches(
            self.training_set, self.th.batch_size, self.th.num_steps,
            self.th.shuffle, is_training=True)

    @staticmethod
    def _check_data_batch(batch):
        """Assert that a batch is a DataSet."""
        assert isinstance(batch, DataSet)
        # The constraint below is not necessary due to gather_indices
        # mechanism
        # if batch.is_rnn_input and batch.active_length is not None:
        #   if max(batch.active_length) > min(batch.active_length):
        #     raise ValueError('!! Sequence batches must be equal-length')

    def _advanced_strategy(self, rnd):
        """Should be overridden"""
        pass

    def _inter_cut(self, content, prompt='>>', start_time=None):
        """Show a status line, then restore the progress bar."""
        # Show content
        console.show_status(content, symbol=prompt)
        # Print progress bar
        self.recover_progress(start_time)

    @staticmethod
    def _dict_to_string(dict_):
        """Format a dict as 'k = v, ...' with values shown to 3 decimals."""
        assert isinstance(dict_, dict)
        string_array = ['{} = {:.3f}'.format(k, v) for k, v in dict_.items()]
        return ', '.join(string_array)

    def _print_progress(self, rnd, loss_dict):
        """Print the loss every th.print_cycle iterations.

        Printing happens on iterations where (counter - 1) is a multiple of
        print_cycle; a print_cycle of 0 turns it off.
        """
        if loss_dict is None or self.th.print_cycle == 0:
            return
        if np.mod(self.counter - 1, self.th.print_cycle) != 0:
            return

        loss_string = self._dict_to_string(loss_dict)
        total_rounds = (' - ' if self.total_rounds is None else
                        ' ({:.1f} Total) '.format(self.total_rounds))
        if not self.is_online:
            content = '{} {}{}{}'.format(
                self.th.round_name, rnd, total_rounds, loss_string)
        else:
            content = 'Iteration {} - {}'.format(self.counter, loss_string)
        self._inter_cut(
            content, prompt='[Train]', start_time=self.th.start_time)

    def _get_tensors_to_export(self):
        """Collect the tensors (and variables) to export in notes.

        Returns an empty OrderedDict when validation is off. Otherwise the
        dict comes from Recurrent or Feedforward, depending on the model's
        input type, and is then extended by _get_variables_to_export.
        """
        from tframe.models.recurrent import Recurrent
        from tframe.models.feedforward import Feedforward

        # This method is based on validation set
        if not self.th.validation_on:
            return OrderedDict()

        if self.model.input_type is InputTypes.RNN_BATCH:
            tensor_dict = Recurrent.get_tensor_to_export(self)
        else:
            tensor_dict = Feedforward.get_tensor_to_export(self)

        # Add variables to export (the dict is extended in place)
        self._get_variables_to_export(tensor_dict)

        return tensor_dict

    def _get_variables_to_export(self, tensor_dict):
        """Add exported variables and monitor statistics to tensor_dict.

        If tensor_dict is non-empty on entry, its values are treated as
        per-exemplar dicts and every new entry is written into each of them;
        otherwise entries are written into tensor_dict itself.

        :param tensor_dict: dict to extend in place, or None for a new one
        :return: the extended dict
        """
        if tensor_dict is None:
            tensor_dict = OrderedDict()
        assert isinstance(tensor_dict, dict)

        base_on_exemplars = len(tensor_dict) > 0

        # Compromise to avoid widget conflict in tensor_viewer
        def _add_to_dict(key, value):
            if base_on_exemplars:
                for exemplar_dict in tensor_dict.values():
                    exemplar_dict[key] = value
            else:
                tensor_dict[key] = value

        # Add variables to export
        v_fetches_dict = context.variables_to_export
        if len(v_fetches_dict) > 0:
            results = self.model.agent.session.run(
                list(v_fetches_dict.values()))
            for key, value in zip(v_fetches_dict.keys(), results):
                _add_to_dict(key, value)

        # :: Add numpy arrays that stored in monitor
        # Add grads stats if necessary
        if self.th.export_weight_grads:
            for key, value in context.monitor.grad_dict.items():
                _add_to_dict(key, value)

        # Add general stats
        if self.th.export_activations:
            for key, value in context.monitor.stats_dict.items():
                _add_to_dict(key, value)

        return tensor_dict

    def _take_notes_for_export(self):
        """Take down scalars and tensors every th.note_modulus iterations."""
        # Note switch should be turned on
        if self.th.note_modulus == 0:
            return
        # Note cycle should be met (or this is the very first iteration and
        # take_note_in_beginning is on)
        if np.mod(self.counter, self.th.note_modulus) != 0:
            if not (self.counter == 1 and self.th.take_note_in_beginning):
                return
        # Loss history should not be blank
        # NOTE(review): a falsy last_value (e.g. 0.0) is also treated as
        # blank here -- confirm
        if not self.batch_loss_stat.last_value:
            return
        # Validation history should not be blank if validation is on
        if self.th.validation_on:
            if not self.metrics_manager.ready_for_note_taking:
                return

        # - Scalars
        scalars = OrderedDict()
        scalars['Loss'] = self.batch_loss_stat.running_average
        self.metrics_manager.update_scalar_dict(scalars)

        # - Tensors
        tensors = self._get_tensors_to_export()
        # Take down
        self.model.agent.take_down_scalars_and_tensors(
            scalars, tensors=tensors)
        self._inter_cut('Notes taken down.', prompt='[Export]')
        # For quickly note taking
        if self.th.terminate_on_note:
            self.th.force_terminate = True

    def _run_probe(self):
        """Call the probe every th.probe_modulus iterations and show its text.

        NOTE(review): returns False when the probe is skipped and None
        otherwise; the call in _inner_loop ignores the result.
        """
        if self._probe is None or self.th.probe_modulus == 0:
            return False
        if np.mod(self.counter, self.th.probe_modulus) != 0:
            return False
        # content = self._probe(self, loss_dict=loss_dict)
        content = self._probe(self)
        if content is None or content == '':
            return
        self._inter_cut(
            content, prompt='[Probe]', start_time=self.th.start_time)

    def _etch(self):
        """Call pruner.etch_all every th.etch_modulus iterations."""
        if not self.th.etch_on:
            return
        if np.mod(self.counter, self.th.etch_modulus) != 0:
            return
        pruner = context.pruner
        assert pruner is not None
        pruner.etch_all()

    def _validate_model(self, rnd):
        """Validate the model every th.validate_modulus iterations.

        The validation set is always validated; training and test sets only
        if the corresponding hub switches are on.

        :return: the flag returned by record_stats_on_dataset for the
                 validation set; False when validation is off or skipped
        """
        if not self.th.validation_on:
            return False
        # Validate cycle should be met (or this is the very first iteration
        # and take_note_in_beginning is on)
        if np.mod(self.counter, self.th.validate_modulus) != 0:
            if not (self.counter == 1 and self.th.take_note_in_beginning):
                return False

        # Validate training set if necessary
        if self.th.validate_train_set:
            train_dict = self.model.validate_model(
                self.training_set, self.th.val_batch_size, allow_sum=False,
                verbose=self.th.val_progress_bar)
            # Record
            self.metrics_manager.record_stats_on_dataset(
                self.training_set, train_dict)

        # Validate val_set and record
        val_dict = self.model.validate_model(
            self.validation_set, self.th.val_batch_size,
            allow_sum=self.th.summary, verbose=self.th.val_progress_bar,
            seq_detail=self.th.val_info_splits > 0)
        new_record = self.metrics_manager.record_stats_on_dataset(
            self.validation_set, val_dict, True, rnd)
        # Terminator will check early_stop_criterion if new_record appears
        if new_record and callable(self._terminator):
            if self._terminator(self.metrics_manager.early_stop_criterion):
                self.th.force_terminate = True

        # Validate test set if necessary TODO: BETA
        if self.th.validate_test_set:
            test_dict = self.model.validate_model(
                self.test_set, self.th.val_batch_size, allow_sum=False,
                verbose=self.th.val_progress_bar)
            # Record
            self.metrics_manager.record_stats_on_dataset(
                self.test_set, test_dict)

        # Print stats and return new_record flag
        self.metrics_manager.print_latest_stats(
            '[Validate]', decimals=self.th.val_decimals)
        return new_record

    def _snapshot(self):
        """Save a figure from the snapshot function every snapshot_cycle."""
        if not self.th.snapshot:
            return
        if not self.th.snapshot_cycle > 0:
            return
        if np.mod(self.counter - 1, self.th.snapshot_cycle) != 0:
            return

        fig = self._snapshot_function(self.model)
        # The file name carries total rounds if available, else the counter
        step = self.counter if self.total_rounds is None else self.total_rounds
        filename = 'train_{:.2f}.png'.format(step)
        self.model.agent.save_plot(fig, filename)
        self._inter_cut("Images saved to '{}'".format(filename), '[Snapshot]')

    def _save_model(self, inter_cut=False, progress=None):
        """Save the model through the agent and report it.

        :param inter_cut: report via _inter_cut instead of
                          console.show_status
        :param progress: optional value in [0, 1] added to model.rounds when
                         training is not online
        """
        # Update model rounds
        total_rounds = None
        if not self.is_online:
            assert np.isscalar(self.model.rounds)
            total_rounds = self.model.rounds
            if progress is not None:
                assert 0 <= progress <= 1
                total_rounds += progress
        # Save model
        self.model.agent.save_model(rounds=total_rounds, suffix='train')
        # Show status
        print_method = self._inter_cut if inter_cut else console.show_status
        print_method('Model saved')

    # endregion : Private Methods

    # region : Public Methods

    def get_variables_to_export(self, export_dict=None):
        """This api is for customized probe method"""
        return self._get_variables_to_export(export_dict)