def test_examples_can_run_one_step(self):
    timer = -time.time()

    # discover all example scripts
    def walk(pa, dst):
        for fn in os.listdir(pa):
            fp = os.path.join(pa, fn)
            if os.path.isdir(fp):
                walk(fp, dst)
            elif fp.endswith('.py'):
                with codecs.open(fp, 'rb', 'utf-8') as f:
                    cnt = f.read()
                if re.search(r'''if\s+__name__\s*==\s+(['"])__main__\1:''',
                             cnt):
                    if 'max_step=config.max_step' not in cnt:
                        raise RuntimeError(
                            'Example script does not have '
                            'max_step configuration: {}'.format(fp))
                    dst.append(fp)
        return dst

    examples_dir = os.path.join(
        os.path.split(os.path.abspath(__file__))[0],
        '../../tfsnippet/examples')
    examples_scripts = walk(examples_dir, [])

    # run all example scripts for just one step
    env_dict = copy.copy(os.environ)
    for example_script in examples_scripts:
        print('Run {} ...'.format(example_script))
        with TemporaryDirectory() as tempdir:
            args = [sys.executable, '-u', example_script, '--max_step=1']
            subprocess.check_call(args, cwd=tempdir, env=env_dict)
        print('')

    # report the finished tests
    print('Finished running {} example scripts in {}.'.format(
        len(examples_scripts), humanize_duration(time.time() + timer)))
def test_ScheduledVariable(self):
    v = ScheduledVariable('v', 123., dtype=tf.int32, model_var=True,
                          collections=['my_variables'])
    assert_variables(['v'], trainable=False, collections=['my_variables'])

    with TemporaryDirectory() as tmpdir:
        saver = tf.train.Saver(var_list=[v.variable])
        save_path = os.path.join(tmpdir, 'saved_var')

        with self.test_session() as sess:
            ensure_variables_initialized()
            self.assertEqual(v.get(), 123)
            self.assertEqual(sess.run(v), 123)

            v.set(456)
            self.assertEqual(v.get(), 456)

            saver.save(sess, save_path)

        with self.test_session() as sess:
            saver.restore(sess, save_path)
            self.assertEqual(v.get(), 456)
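

# The test above exercises the get()/set() API of ScheduledVariable together
# with tf.train.Saver.  The sketch below is illustrative only (it is not part
# of the test suite); the variable name and decay schedule are hypothetical,
# and it relies only on the get()/set() calls shown above.
def _example_scheduled_learning_rate():  # pragma: no cover
    """Illustrative sketch: drive a learning rate from Python between epochs.

    Assumes a default graph and an initialized default session.
    """
    lr = ScheduledVariable('learning_rate', 0.001, dtype=tf.float32)
    ensure_variables_initialized()
    for epoch in range(1, 11):
        # ... run one epoch of training, feeding lr.get() as the learning rate
        if epoch % 5 == 0:
            lr.set(lr.get() * 0.5)  # halve the learning rate every 5 epochs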
def test_save_dir(self):
    with self.get_session():
        a, b, c = _populate_variables()
        self.assertEqual(get_variable_values([a, b, c]), [1, 2, 3])

        with TemporaryDirectory() as tempdir:
            # test cleanup save_dir
            save_dir = os.path.join(tempdir, '1')
            with early_stopping([a, b], save_dir=save_dir) as es:
                self.assertTrue(es.update(1.))
                self.assertTrue(
                    os.path.exists(os.path.join(save_dir, 'latest')))
            self.assertFalse(os.path.exists(save_dir))

            # test not cleanup save_dir
            save_dir = os.path.join(tempdir, '2')
            with early_stopping([a, b], save_dir=save_dir,
                                cleanup=False) as es:
                self.assertTrue(es.update(1.))
                self.assertTrue(
                    os.path.exists(os.path.join(save_dir, 'latest')))
            self.assertTrue(
                os.path.exists(os.path.join(save_dir, 'latest')))
def test_fit(self):
    values, labels, missing, excludes = self._payload()

    with TemporaryDirectory() as tmpdir:
        tf.set_random_seed(1234)
        donut = Donut(h_for_p_x=lambda x: x, h_for_q_z=lambda x: x,
                      x_dims=5, z_dims=3)
        trainer = DonutTrainer(donut, max_epoch=3, batch_size=7,
                               valid_step_freq=50, lr_anneal_epochs=2)
        with self.test_session():
            trainer.fit(values=values, labels=labels, missing=missing,
                        mean=1., std=2., excludes=excludes,
                        summary_dir=tmpdir)
class TrainLoop(DisposableContext): """ Training loop object. This class provides a set of convenient methods for writing training loop. It is useful for maintaining epoch and step counters, logging training metrics, memorizing best parameters for early-stopping, etc. An example of using the :class:`TrainLoop`:: import tfsnippet as spt with spt.TrainLoop(param_vars, max_epoch=10, early_stopping=True) as loop: loop.print_training_summary() train_flow = spt.DataFlow.arrays([x, y], batch_size, shuffle=True) for epoch in loop.iter_epochs(): for step, (x, y) in loop.iter_steps(train_flow): step_loss = session.run( [loss, train_op], feed_dict={input_x: x, input_y: y} ) loop.collect_metrics(loss=step_loss) with loop.timeit('valid_time'): valid_loss = session.run( loss, feed_dict={input_x: test_x, input_y: test_y}) loop.collect_metrics(valid_loss=valid_loss) loop.print_logs() The event schedule of a :class:`TrainLoop` can be briefly described as:: # the main training loop events.fire(EventKeys.ENTER_LOOP, self) for epoch in self.iter_epochs(): events.fire(EventKeys.BEFORE_EPOCH, self) for step in self.iter_steps(...): events.fire(EventKeys.BEFORE_STEP, self) ... # execute the step events.reverse_fire(EventKeys.AFTER_STEP, self) events.reverse_fire(EventKeys.AFTER_EPOCH, self) events.fire(EventKeys.EXIT_LOOP, self) # when metrics are fed into the loop by :meth:`collect_metrics` def collect_metrics(self, metrics_dict=None, **kwargs): metrics_dict = merge(metrics_dict, kwargs) events.fire(EventKeys.METRICS_COLLECTED, self, metrics_dict) # when summaries are fed into the loop by :meth:`add_summary` def add_summary(self, summary): events.fire(EventKeys.SUMMARY_ADDED, self, summary) # when metric statistics have been printed as log def print_logs(self): ... events.fire(EventKeys.METRIC_STATS_PRINTED, self, metric_stats) events.fire(EventKeys.TIME_METRIC_STATS_PRINTED, self, time_metric_stats) Warning: If you use early-stopping along with checkpoint, there is one case which is very dangerous: you've already successfully done a training loop, and the early-stopping variables have been restored. But you then recover from the latest checkpoint and continue to train. In this case, the `param_vars` (which is covered by early-stopping) are restored to the best validation step, but the other variables and the internal states of :class:`TrainLoop` are recovered to the last step. Then you obtain a state mismatch, and the behaviour will be un-predictable after this recovery. """ def __init__( self, param_vars, var_groups=None, show_eta=True, print_func=print, max_epoch=None, max_step=None, metric_formatter=DefaultMetricFormatter(), # checkpoint related arguments checkpoint_dir=None, checkpoint_epoch_freq=None, checkpoint_max_to_keep=None, checkpoint_save_objects=None, restore_checkpoint=True, # summary related arguments summary_dir=None, summary_writer=None, summary_graph=None, summary_metric_prefix='metrics/', summary_skip_pattern=re.compile(r'.*(time|timer)$'), summary_commit_freqs=None, # validation and early-stopping related arguments valid_metric_name='valid_loss', valid_metric_smaller_is_better=None, early_stopping=False): """ Construct the :class:`TrainLoop`. Args: param_vars (list[tf.Variable] or dict[str, tf.Variable]): List or dict of variables, optimized during training. var_groups (None or list[str]): Variable groups, the prefixes of variable scopes. A hint for printing the variables summary. (default :obj:`None`) show_eta (bool): Whether or not to show ETA? 
(default :obj:`True`) print_func ((str) -> None): Function for printing log messages (calling ``print`` by default). An alternative of this argument may be ``getLogger(__name__).info``, such that the log messages will be printed via logging facilities. max_epoch (None or int or tf.Tensor or tf.Variable): The maximum epoch to run. If :obj:`None`, will run for infinite epochs. If ``1``, the epoch counter will be discarded in the output logs. (default :obj:`None`) max_step (None or int or tf.Tensor or tf.Variable): The maximum step to run. If :obj:`None`, will run for infinite steps. Note this limit applies for the total step counter, rather than the epoch-wise step counter. (default :obj:`None`) metric_formatter (MetricFormatter): The training metrics formatter. checkpoint_dir (str): If specified, will save checkpoint files to this directory, when :meth:`make_checkpoint()` is called. checkpoint_epoch_freq (int or None): If specified, will make checkpoint every this number of epochs. If not specified, you must call :meth:`make_checkpoint()` manually. checkpoint_max_to_keep (int or None): Maximum number of checkpoint versions to keep. If :obj:`None` or `0`, keep all versions. checkpoint_save_objects (dict[str, CheckpointSavableObject]): If specified, will save and restore the states of these objects. restore_checkpoint (bool or str): If :obj:`True`, will restore the latest checkpoint. If a str, it should be the path of a checkpoint file, and will restore from this checkpoint. If :obj:`False`, will not restore the from the checkpoint files (but will still save new checkpoints if `checkpoint_dir` if specified). summary_dir (str): Directory for writing TensorFlow summaries. Ignored if `summary_writer` is specified. summary_writer: TensorFlow summary writer for writing metrics. summary_metric_prefix (str): The prefix for the metrics committed to `summary_writer`. This will not affect the summaries added via :meth:`add_summary`. (default "") summary_graph: If specified, log the graph via `summary_writer`. summary_skip_pattern (str or regex): Metrics matching this pattern will be excluded from `summary_writer`. (default ".*(time|timer)$") summary_commit_freqs (dict[str, int] or None): If specified, a metric will be committed to `summary_writer` no more frequent than ``summary_commit_freqs[metric]``. (default :obj:`None`) valid_metric_name (str): Name of the validation metric. valid_metric_smaller_is_better (bool): Whether or not the smaller value is better for validation metric? If not specified, it will be inferred according to `valid_metric_name`: metric names with ``acc`` or ``accuracy`` as suffix imply :obj:`True`, while other names imply :obj:`False`. early_stopping (bool): Whether or not to do early-stopping? (default :obj:`False`) If :obj:`True`, early-stopping will be applied on `param_vars`, according to the validation metric. The variables will only be restored if the training loop is exited without any error or interruption, including the Ctrl+C KeyboardInterrupt. 
""" # regularize the parameters if not isinstance(param_vars, (dict, OrderedDict)): param_vars = list(param_vars) if isinstance(max_epoch, (tf.Variable, tf.Tensor)): max_epoch = int(max_epoch.eval()) if isinstance(max_step, (tf.Variable, tf.Tensor)): max_step = int(max_step.eval()) if checkpoint_dir is not None: checkpoint_dir = os.path.abspath(checkpoint_dir) if checkpoint_epoch_freq is not None: checkpoint_epoch_freq = int(checkpoint_epoch_freq) if checkpoint_epoch_freq < 1: raise ValueError( '`checkpoint_epoch_freq` must be a positive integer: ' 'got {}'.format(checkpoint_epoch_freq)) if isinstance(restore_checkpoint, six.string_types): if early_stopping: raise ValueError( 'Currently `early_stopping = True` is not supported when ' 'a file path is specified for `restore_checkpoint`.') restore_checkpoint = os.path.abspath(restore_checkpoint) save_objects = dict(checkpoint_save_objects or ()) for key in (TRAIN_LOOP_STATES_CKPT_NAME, EARLY_STOPPING_STATES_CKPT_NAME): if key in save_objects: raise KeyError('Name is reserved for `checkpoint_save_objects`' ': {}'.format(key)) if summary_writer is not None: summary_dir = None own_summary_writer = False elif summary_dir is not None: summary_dir = os.path.abspath(summary_dir) own_summary_writer = True else: own_summary_writer = False smaller_is_better = valid_metric_smaller_is_better if smaller_is_better is None: smaller_is_better = not (valid_metric_name.endswith('acc') or valid_metric_name.endswith('accuracy')) # memorize the parameters self._param_vars = copy.copy(param_vars) self._var_groups = list(var_groups) if var_groups else None self._print_func = print_func self._show_eta = show_eta self._max_epoch = max_epoch self._max_step = max_step self._metric_formatter = metric_formatter self._summary_dir = summary_dir self._summary_writer = summary_writer self._summary_metric_prefix = summary_metric_prefix self._summary_graph = summary_graph self._summary_skip_pattern = summary_skip_pattern self._summary_commit_freqs = dict(summary_commit_freqs or ()) self._own_summary_writer = own_summary_writer self._use_early_stopping = early_stopping self._valid_metric_name = valid_metric_name self._valid_metric_smaller_is_better = smaller_is_better # the event source self._events = EventSource([ EventKeys.ENTER_LOOP, EventKeys.EXIT_LOOP, EventKeys.BEFORE_EPOCH, EventKeys.AFTER_EPOCH, EventKeys.BEFORE_STEP, EventKeys.AFTER_STEP, EventKeys.METRICS_COLLECTED, EventKeys.TIME_METRICS_COLLECTED, EventKeys.METRIC_STATS_PRINTED, EventKeys.TIME_METRIC_STATS_PRINTED, EventKeys.SUMMARY_ADDED, ]) # the restorable train loop states self._states = TrainLoopStates() # initialize the checkpoint saver self._checkpoint_dir = checkpoint_dir self._checkpoint_epoch_freq = checkpoint_epoch_freq self._restore_checkpoint = restore_checkpoint self._checkpoint_saver = None if checkpoint_dir: getLogger(__name__).debug( 'Global variables to save at checkpoints: %s', tf.global_variables()) save_objects[TRAIN_LOOP_STATES_CKPT_NAME] = self._states self._checkpoint_saver = CheckpointSaver( tf.global_variables(), objects=save_objects, save_dir=os.path.join(checkpoint_dir, 'checkpoint'), max_to_keep=checkpoint_max_to_keep, save_meta=False) # the checkpoint saver for early stopping # if checkpoint_dir is None, we postpone the initialization until # enter the loop. 
self._early_stopping_saver = None self._early_stopping_temp_dir = None # type: TemporaryDirectory if checkpoint_dir is not None and self._use_early_stopping: self._early_stopping_saver = CheckpointSaver(self._param_vars, save_dir=os.path.join( checkpoint_dir, 'early_stopping'), max_to_keep=2, save_meta=False) # euphemeral train loop states self._eta = None self._step_metrics = None # type: MetricLogger self._epoch_metrics = None # type: MetricLogger self._within_epoch = False self._within_step = False self._steps_per_epoch = None # average steps per epoch self._is_best_valid_metric = False self._epoch_start_time = None self._step_start_time = None # the active data flow of current epoch self._data_flow = None # type: DataFlow self._step_data = None # the data of the current step def _enter(self): # open the summary writer if required if self._summary_dir is not None: self._summary_writer = tf.summary.FileWriter( self._summary_dir, graph=self._summary_graph) # create the metric accumulators self._step_metrics = MetricLogger(formatter=self._metric_formatter) self._epoch_metrics = MetricLogger( summary_writer=self._summary_writer, summary_metric_prefix=self._summary_metric_prefix, summary_skip_pattern=self._summary_skip_pattern, summary_commit_freqs=self._summary_commit_freqs, formatter=self._metric_formatter) # create the early-stopping saver if required if self._use_early_stopping: if self._early_stopping_saver is None: self._early_stopping_temp_dir = TemporaryDirectory() dir_path = self._early_stopping_temp_dir.__enter__() self._early_stopping_saver = CheckpointSaver( self._param_vars, save_dir=dir_path, max_to_keep=2, save_meta=False, ) # restore the checkpoint if self._checkpoint_saver is not None: checkpoint_file = None if isinstance(self._restore_checkpoint, six.string_types): checkpoint_file = str(self._restore_checkpoint) elif self._restore_checkpoint: checkpoint_file = self._checkpoint_saver.latest_checkpoint() if checkpoint_file: self._checkpoint_saver.restore(checkpoint_file) self.println( 'Resume training: epoch {}, step {}, from checkpoint {}'. format(self.epoch, self.step, checkpoint_file)) # initialize the eta flags self._eta = ETA() # trigger the event self.events.on(EventKeys.ENTER_LOOP, self) # return self as the context object return self def _exit(self, exc_type, exc_val, exc_tb): try: # close the summary writer if self._own_summary_writer: self._summary_writer.close() self._summary_writer = None self._own_summary_writer = False # restore the early-stopping variables if no error if self._early_stopping_saver is not None: if exc_type is None: es_latest = self._early_stopping_saver.latest_checkpoint() if es_latest is None: # pragma: no cover warnings.warn( 'Early-stopping has never been triggered! ' 'The variables will keep their latest values. 
' 'Did you forget to add corresponding metric?') else: self._early_stopping_saver.restore(es_latest) self._early_stopping_saver = None else: # pragma: no cover warnings.warn( 'Early-stopping variables are not restored, because ' 'an error or an interruption has occurred.') finally: try: if self._early_stopping_temp_dir is not None: self._early_stopping_temp_dir.__exit__( exc_type, exc_val, exc_tb) self._early_stopping_temp_dir = None except Exception: getLogger(__name__).warning( 'Failed to cleanup early-stopping temporary directory.', exc_info=True) # clear status self._steps_per_epoch = None self._eta = None # trigger the event self.events.on(EventKeys.EXIT_LOOP, self) def _commit_epoch_stop_time(self): if self._epoch_start_time is not None: duration = time.time() - self._epoch_start_time self.collect_metrics(metrics={EPOCH_TIME_METRIC: duration}) self._epoch_start_time = None def _commit_step_stop_time(self): if self._step_start_time is not None: duration = time.time() - self._step_start_time self.collect_metrics(metrics={STEP_TIME_METRIC: duration}) self._step_start_time = None def get_progress(self): """ Get the progress of training. Returns: float or None: The progress in range ``[0, 1]``, or None if the progress cannot be estimated. """ max_step = self.max_step if max_step is None and self.max_epoch is not None and \ self._steps_per_epoch is not None: max_step = self.max_epoch * self._steps_per_epoch if max_step: if self._within_step and self._step_start_time is not None: # _step_start_time != None, indicating the step not finished return (self.step - 1.) / max_step else: return float(self.step) / max_step elif self.max_epoch is not None: if self._within_epoch and self._epoch_start_time is not None: # _epoch_start_time != None, indicating the epoch not finished return (self.epoch - 1.) / self.max_epoch else: return float(self.epoch) / self.max_epoch @property def param_vars(self): """Get the trainable parameter variables.""" return self._param_vars @property def var_groups(self): """Get the variable groups.""" return self._var_groups @property def max_epoch(self): """Get or set the max value for epoch counter.""" return self._max_epoch @max_epoch.setter def max_epoch(self, value): self._max_epoch = int(value) @property def max_step(self): """Get or set the max value for global step counter.""" return self._max_step @max_step.setter def max_step(self, value): self._max_step = int(value) @property def summary_writer(self): """Get the summary writer instance.""" return self._summary_writer @property def events(self): """ Get the event source. Returns: EventSource: The event source. 
""" return self._events @property def epoch(self): """Get the epoch counter (starting from 1).""" return self._states.epoch @property def step(self): """Get the global step counter (starting from 1).""" return self._states.step @property def step_data(self): """Get the data of current step.""" return self._step_data @property def use_early_stopping(self): """Whether or not to adopt early-stopping?""" return self._use_early_stopping @property def valid_metric_name(self): """Get the name of the validation metric.""" return self._valid_metric_name @property def valid_metric_smaller_is_better(self): """Whether or not the smaller value is better for validation metric?""" return self._valid_metric_smaller_is_better @property def best_valid_metric(self): """Get the best valid metric.""" return self._states.best_valid_metric @property def within_epoch(self): """Whether or not an epoch is open?""" return self._within_epoch @property def within_step(self): """Whether or not a step is open?""" return self._within_step def make_checkpoint(self): """ Make a checkpoint. This method must be called within an eopch or a step context. For example:: for epoch in loop.iter_epochs(): for [x] in loop.iter_steps(train_data): ... if epoch % 100 == 0: loop.make_checkpoint() """ if not self._checkpoint_saver: raise RuntimeError('Checkpoint directory is not configured.') self._checkpoint_saver.save(self._states.step) def iter_epochs(self): """ Iterate through the epochs. This method can only be called when there's no other epoch loop is being iterated. Furthermore, after exiting this loop, both the epoch metrics as well as the step metrics will be cleared. If `max_epoch` is configured, it will stop at it. Yields: int: The epoch counter (starting from 1). """ def loop_condition(): return ((self._max_epoch is None or self.epoch < self._max_epoch) and (self._max_step is None or self.step < self._max_step)) self._require_entered() if self._within_epoch: raise RuntimeError('Another epoch loop has been opened') try: while loop_condition(): self._states.epoch += 1 self._within_epoch = True self._epoch_start_time = time.time() self.events.fire(EventKeys.BEFORE_EPOCH, self) yield self.epoch self.events.reverse_fire(EventKeys.AFTER_EPOCH, self) self._commit_epoch_stop_time() self._steps_per_epoch = float(self.step) / self.epoch # do checkpoint if configured if self._checkpoint_epoch_freq is not None and \ self.epoch % self._checkpoint_epoch_freq == 0: self.make_checkpoint() finally: self._within_epoch = False self._epoch_start_time = None self._step_metrics.clear() self._epoch_metrics.clear() self._is_best_valid_metric = False def iter_steps(self, data_generator=None): """ Iterate through the steps. This method can only be called when there's no other step loop is being iterated, and an epoch loop is active. Args: data_generator: Optional iterable data to be yielded at every step. This is required if `max_step` is not configured, so as to prevent an infinite step loop. Yields: int or (int, any): The global step counter (starting from 1), or the tuple of ``(step counter, batch data)`` if `data_generator` is specified. 
""" def loop_condition(): return self._max_step is None or self.step < self._max_step self._require_entered() if not self._within_epoch: raise RuntimeError('Step loop must be opened within active epoch ' 'loop') if self._within_step: raise RuntimeError('Another step loop has been opened') if self._max_step is None and data_generator is None: raise RuntimeError('`data_generator` is required when `max_step` ' 'is not configured, so as to prevent an ' 'unstoppable step loop') try: if data_generator is not None: if isinstance(data_generator, DataFlow): data_flow = data_generator else: def iter_factory(): if data_gen[0] is not None: for batch in data_gen[0]: yield batch data_gen[0] = None # force to use data_generator once data_gen = [data_generator] data_flow = DataFlow.iterator_factory(iter_factory) self._data_flow = data_flow while loop_condition(): # prepare for the step data if self._data_flow is None: yield_obj = self.step + 1 step_data = None else: try: step_data = self._data_flow.next_batch() except StopIteration: break yield_obj = self.step + 1, step_data # yield this step self._states.step += 1 self._within_step = True self._step_data = step_data self._step_start_time = time.time() self.events.fire(EventKeys.BEFORE_STEP, self) try: yield yield_obj except StopIteration: # pragma: no cover # might be caused by call to ``data_flow.next_batch()`` break self.events.reverse_fire(EventKeys.AFTER_STEP, self) self._commit_step_stop_time() finally: self._within_step = False self._step_start_time = None self._data_flow = None self._step_data = None def _require_context(self): self._require_entered() if not self._within_epoch and not self._within_step: raise RuntimeError('An epoch or a step loop is expected, but ' 'neither has been opened') @contextmanager def timeit(self, metric_name): """ Open a context for timing. Args: metric_name (str): Store the timing result in metric of this name. Note that `metric_name` must end with ``time`` or ``timer``, otherwise by default the time values will not be formatted as human readable strings. """ self._require_context() start_time = time.time() yield duration = time.time() - start_time self._collect_metrics({metric_name: duration}, EventKeys.TIME_METRICS_COLLECTED) @contextmanager def metric_collector(self, metric_name): """ Get a :class:`~tfsnippet.utils.StatisticsCollector` for metric. The mean value of the collected metrics will be added to summary after exiting the context. Other statistics will be discarded. Args: metric_name (str): The name of this metric. Yields: StatisticsCollector: The collector for metric values. 
""" self._require_context() acc = StatisticsCollector() yield acc if acc.has_value: self.collect_metrics(metrics={metric_name: acc.mean}) def _collect_metrics(self, metrics, event_key): self._require_context() # update the metrics self._epoch_metrics.collect_metrics(metrics, global_step=self.step) if self._within_step: self._step_metrics.collect_metrics(metrics, global_step=self.step) self.events.fire(event_key, self, metrics) # update the validation metric def update_valid_metric(d): v = d.get(self.valid_metric_name) if v is not None: if self.best_valid_metric is None or \ (self._valid_metric_smaller_is_better and v < self.best_valid_metric) or \ (not self._valid_metric_smaller_is_better and v > self.best_valid_metric): # we've met a new best metric self._states.best_valid_metric = v self._is_best_valid_metric = True # early-stopping save variables if self._early_stopping_saver is not None: self._early_stopping_saver.save(global_step=self.step) else: self._is_best_valid_metric = False if self.valid_metric_name: if metrics: update_valid_metric(metrics) def collect_metrics(self, metrics=None, **kwargs): """ Add metric values. This method must be called when there's at least an active epoch loop. It will add metrics to the epoch metrics collector, and if there's an active step loop, it will also add metrics to the step metrics collector. If `summary_writer` is configured, it will also write the metrics as summaries onto disk. Furthermore, if `valid_metric_name` is configured, it will also perform early-stopping. Args: metrics (dict[str, float or np.ndarray]): Metric values as dict. **kwargs: Metric values, specified as named arguments. """ if metrics is None: metrics = {} elif metrics is not None and not isinstance(metrics, dict): raise TypeError('`metrics` should be a dict') else: metrics = dict(metrics) metrics.update(kwargs) self._collect_metrics(metrics, EventKeys.METRICS_COLLECTED) def add_summary(self, summary): """ Add a summary object, with ``self.step`` as `global_step`. Args: summary (tf.summary.Summary or bytes): TensorFlow summary object, or serialized summary. """ self._require_entered() self._summary_writer.add_summary(summary, global_step=self.step) self.events.fire(EventKeys.SUMMARY_ADDED, self, summary) def get_eta(self): """ Get the estimated time ahead (ETA). Returns: float or None: The estimated time ahead in seconds, or None if not available. """ progress = self.get_progress() if progress is not None: return self._eta.get_eta(progress) def println(self, message, with_tag=False): """ Print `message` via `print_function`. Args: message (str): Message to be printed. with_tag (bool): Whether or not to add the epoch & step tag? (default :obj:`False`) """ if with_tag: def format_tag(v, max_v, name): if max_v is not None: return '{} {}/{}'.format(name, v, max_v) else: return '{} {}'.format(name, v) if not self._within_step and not self._within_epoch: self._require_context() tags = [] if self._max_epoch != 1: tags.append(format_tag(self.epoch, self._max_epoch, 'Epoch')) tags.append(format_tag(self.step, self._max_step, 'Step')) if self._show_eta: eta = self.get_eta() if eta is not None: tags.append('ETA {}'.format(humanize_duration(eta))) message = '[{}] {}'.format(', '.join(tags), message) self._print_func(message) def print_training_summary(self): """ Print the training summary. The training summary include the following content: 1. Execution environment. 2. Parameters to be optimized during training. 
""" self._require_entered() self.println( summarize_variables(variables=self._param_vars, title='Trainable Parameters', other_variables_title='Other Parameters', groups=self.var_groups)) self.println('') def print_logs(self): """ Print the training logs. This method will print the collected metrics. If there's an active step loop, it will print metrics from the step metrics collector. Otherwise if there's only an epoch loop, it will print metrics from the epoch metrics accumulator. Note it must be called at the end of an epoch or a step. This is because the metrics of corresponding loop context will be cleared after the logs are printed. Moreover, the epoch or step timer will be committed as metric immediately when this method is called, before printing the logs. """ self._require_entered() metrics = None if self._within_step: self._commit_step_stop_time() metrics = self._step_metrics elif self._within_epoch: self._commit_epoch_stop_time() metrics = self._epoch_metrics else: self._require_context() best_mark = ' (*)' if self._is_best_valid_metric else '' self.println(metrics.format_logs() + best_mark, with_tag=True) self._is_best_valid_metric = False # fire the metric stats printed events metric_stats = { key: val for key, val in six.iteritems(metrics.metrics) if not TIME_METRIC_PATTERN.match(key) } time_metric_stats = { key: val for key, val in six.iteritems(metrics.metrics) if TIME_METRIC_PATTERN.match(key) } if metric_stats: self.events.fire(EventKeys.METRIC_STATS_PRINTED, self, metric_stats) if time_metric_stats: self.events.fire(EventKeys.TIME_METRIC_STATS_PRINTED, self, time_metric_stats) metrics.clear()
def test_summary_writer(self): def read_summary(summary_dir): # read the metric summary loss_steps = [] loss_values = [] valid_loss_steps = [] valid_loss_values = [] x_steps = [] x_values = [] tags = set() event_file_path = os.path.join( summary_dir, os.listdir(summary_dir)[0]) for e in tf.train.summary_iterator(event_file_path): for v in e.summary.value: tags.add(v.tag) if v.tag == 'metrics/loss': loss_steps.append(e.step) loss_values.append(v.simple_value) elif v.tag == 'metrics/valid_loss': valid_loss_steps.append(e.step) valid_loss_values.append(v.simple_value) elif v.tag == 'x': x_steps.append(e.step) x_values.append(v.simple_value) return (tags, loss_steps, loss_values, valid_loss_steps, valid_loss_values, x_steps, x_values) # test enable summary with `summary_dir` with TemporaryDirectory() as tempdir: with TrainLoop([], max_epoch=2, summary_dir=tempdir, summary_graph=tf.get_default_graph()) as loop: self.assertIsInstance(loop.summary_writer, tf.summary.FileWriter) self.assertIsNone(loop._early_stopping) for epoch in loop.iter_epochs(): for _, loss in loop.iter_steps([0.7, 0.6, 0.8]): loop.collect_metrics(loss=epoch + loss) loop.collect_metrics(valid_loss=epoch) with self.test_session(): summary_op = tf.summary.scalar('x', tf.constant(1.23)) loop.add_summary(summary_op.eval()) obj = read_summary(tempdir) self.assertEqual( ['metrics/loss', 'metrics/valid_loss', 'x'], sorted(obj[0]) ) np.testing.assert_equal(obj[1], [1, 2, 3, 4, 5, 6]) np.testing.assert_almost_equal( obj[2], [1.7, 1.6, 1.8, 2.7, 2.6, 2.8] ) np.testing.assert_equal(obj[3], [3, 6]) np.testing.assert_almost_equal(obj[4], [1, 2]) np.testing.assert_equal(obj[5], [6]) np.testing.assert_almost_equal(obj[6], [1.23]) # test enable summary with `summary_writer` with TemporaryDirectory() as tempdir: sw = tf.summary.FileWriter(tempdir) with TrainLoop([], max_epoch=2, summary_writer=sw) as loop: self.assertIs(loop.summary_writer, sw) self.assertIsNone(loop._early_stopping) self.assertIs(loop._summary_writer, sw) for epoch in loop.iter_epochs(): for _, loss in loop.iter_steps([0.7, 0.6, 0.8]): loop.collect_metrics(loss=epoch + loss) loop.collect_metrics(valid_loss=epoch) sw.close() self.assertEqual( sorted(read_summary(tempdir)[0]), ['metrics/loss', 'metrics/valid_loss'] ) with TemporaryDirectory() as tempdir: sw = tf.summary.FileWriter(tempdir) with TrainLoop([], max_epoch=2, summary_writer=sw) as loop: self.assertIsNone(loop._early_stopping) self.assertIs(loop._summary_writer, sw) for epoch in loop.iter_epochs(): for _, loss in loop.iter_steps([0.7, 0.6, 0.8]): loop.collect_metrics(loss=epoch + loss) loop.collect_metrics(valid_loss=epoch) sw.close() self.assertEqual( sorted(read_summary(tempdir)[0]), ['metrics/loss', 'metrics/valid_loss'] )
class EarlyStopping(DisposableContext):
    """
    Early-stopping context object.

    This class provides an object for memorizing the parameters for the best
    metric, in an early-stopping context.  An example of using this context:

    .. code-block:: python

        with EarlyStopping(param_vars) as es:
            ...
            es.update(loss, global_step)
            ...

    Where ``es.update(loss, global_step)`` should cause the parameters to
    be saved on disk if `loss` is better than the current best metric.
    One may also get the current best metric via ``es.best_metric``.

    Notes:
        If no loss is given via ``es.update``, then the variables would keep
        their latest values when closing an early-stopping object.
    """

    def __init__(self, param_vars, initial_metric=None, checkpoint_dir=None,
                 smaller_is_better=True, restore_on_error=False, cleanup=True,
                 name=None):
        """
        Construct the :class:`EarlyStopping`.

        Args:
            param_vars (list[tf.Variable] or dict[str, tf.Variable]):
                List or dict of variables to be memorized.  If a dict is
                specified, the keys of the dict would be used as the
                serialization keys via :class:`VariableSaver`.
            initial_metric (float or tf.Tensor or tf.Variable): The initial
                best metric (for recovering from a previous session).
            checkpoint_dir (str): The directory where to save the checkpoint
                files.  If not specified, will use a temporary directory.
            smaller_is_better (bool): Whether or not it is better to have
                smaller metric values? (default :obj:`True`)
            restore_on_error (bool): Whether or not to restore the memorized
                parameters even on error? (default :obj:`False`)
            cleanup (bool): Whether or not to cleanup the checkpoint directory
                on exit?  This argument will be ignored if `checkpoint_dir` is
                :obj:`None`, where the temporary directory will always be
                deleted on exit.
            name (str): Name scope of all TensorFlow operations.
                (default "early_stopping").
""" # regularize the parameters if not param_vars: raise ValueError('`param_vars` must not be empty') if isinstance(initial_metric, (tf.Tensor, tf.Variable)): initial_metric = initial_metric.eval() if checkpoint_dir is not None: checkpoint_dir = os.path.abspath(checkpoint_dir) # memorize the parameters self._param_vars = copy.copy(param_vars) self._checkpoint_dir = checkpoint_dir self._smaller_is_better = smaller_is_better self._restore_on_error = restore_on_error self._cleanup = cleanup self._name = name # internal states of the object self._best_metric = initial_metric self._ever_updated = False self._temp_dir_ctx = None self._saver = None # type: VariableSaver def _enter(self): # open a temporary directory if the checkpoint dir is not specified if self._checkpoint_dir is None: self._temp_dir_ctx = TemporaryDirectory() self._checkpoint_dir = self._temp_dir_ctx.__enter__() else: makedirs(self._checkpoint_dir, exist_ok=True) # create the variable saver self._saver = VariableSaver(self._param_vars, self._checkpoint_dir) # return self as the context object return self def _exit(self, exc_type, exc_val, exc_tb): try: # restore the variables # exc_info = (exc_type, exc_val, exc_tb) if exc_type is None or exc_type is KeyboardInterrupt or \ self._restore_on_error: self._saver.restore(ignore_non_exist=True) finally: # cleanup the checkpoint directory try: if self._temp_dir_ctx is not None: self._temp_dir_ctx.__exit__(exc_type, exc_val, exc_tb) elif self._cleanup: if os.path.exists(self._checkpoint_dir): shutil.rmtree(self._checkpoint_dir) except Exception: # pragma: no cover getLogger(__name__).error( 'Failed to cleanup validation save dir %r.', self._checkpoint_dir, exc_info=True) # warning if metric never updated if not self._ever_updated: warnings.warn('Early-stopping metric has never been updated. ' 'The variables will keep their latest values. ' 'Did you forget to add corresponding metric?') def update(self, metric, global_step=None): """ Update the best metric. Args: metric (float): New metric value. global_step (int): Optional global step counter. Returns: bool: Whether or not the best loss has been updated? """ self._require_entered() self._ever_updated = True if self._best_metric is None or \ (self._smaller_is_better and metric < self._best_metric) or \ (not self._smaller_is_better and metric > self._best_metric): self._saver.save(global_step) self._best_metric = metric return True return False @property def best_metric(self): """Get the current best loss.""" return self._best_metric @property def ever_updated(self): """Check whether or not `update` method has ever been called.""" return self._ever_updated
def test_summary_writer(self):
    with TemporaryDirectory() as tempdir:
        # generate the metric summary
        with contextlib.closing(tf.summary.FileWriter(tempdir)) as sw:
            logger = MetricLogger(
                sw,
                summary_skip_pattern=r'.*(time|timer)$',
                summary_commit_freqs={'every_two': 2}
            )
            step = 0
            for epoch in range(1, 3):
                for data in range(10):
                    step += 1
                    logger.collect_metrics({'acc': step * 100 + data}, step)
                    logger.collect_metrics({'time': epoch}, step)
                    logger.collect_metrics({'every_two': step * 2}, step)

                with self.test_session(use_gpu=False):
                    logger.collect_metrics(
                        {'valid_loss': -epoch}, tf.constant(step))

        # read the metric summary
        acc_steps = []
        acc_values = []
        valid_loss_steps = []
        valid_loss_values = []
        every_two_steps = []
        every_two_values = []
        tags = set()

        event_file_path = os.path.join(tempdir, os.listdir(tempdir)[0])
        for e in tf.train.summary_iterator(event_file_path):
            for v in e.summary.value:
                tags.add(v.tag)
                if v.tag == 'acc':
                    acc_steps.append(e.step)
                    acc_values.append(v.simple_value)
                elif v.tag == 'valid_loss':
                    valid_loss_steps.append(e.step)
                    valid_loss_values.append(v.simple_value)
                elif v.tag == 'every_two':
                    every_two_steps.append(e.step)
                    every_two_values.append(v.simple_value)

        self.assertEqual(sorted(tags), ['acc', 'every_two', 'valid_loss'])
        np.testing.assert_equal(acc_steps, np.arange(1, 21))
        np.testing.assert_almost_equal(
            acc_values,
            np.arange(1, 21) * 100 + np.concatenate([
                np.arange(10), np.arange(10)
            ])
        )
        np.testing.assert_equal(every_two_steps, np.arange(1, 21, 2))
        np.testing.assert_almost_equal(
            every_two_values,
            np.arange(1, 21, 2) * 2
        )
        np.testing.assert_equal(valid_loss_steps, [10, 20])
        np.testing.assert_almost_equal(
            valid_loss_values, [-1, -2]
        )
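

# Besides writing TensorFlow summaries (as tested above), MetricLogger also
# accumulates statistics that can be rendered as a log line; this is how
# TrainLoop.print_logs() uses it.  The sketch below is illustrative only and
# assumes the logger can be constructed without a summary writer.
def _example_metric_logger_format():  # pragma: no cover
    logger = MetricLogger()
    for step in range(1, 4):
        logger.collect_metrics({'loss': 1.0 / step}, global_step=step)
    print(logger.format_logs())  # prints the accumulated metric statistics
    logger.clear()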
def early_stopping(param_vars, initial_metric=None, save_dir=None,
                   smaller_is_better=True, restore_on_error=False,
                   cleanup=True, name=None):
    """Open a context to memorize the values of parameters at best metric.

    This method will open a context with an object to memorize the best
    metric for early-stopping.  An example of using this early-stopping
    context is:

        with early_stopping(param_vars) as es:
            ...
            es.update(loss, global_step)
            ...

    Where ``es.update(loss, global_step)`` should cause the parameters to
    be saved on disk if `loss` is better than the current best metric.
    One may also get the best metric via ``es.best_metric``.

    Note that if no loss is given via ``es.update``, then the variables
    would keep their latest values when exiting the early-stopping context.

    Parameters
    ----------
    param_vars : list[tf.Variable] | dict[str, tf.Variable]
        List or dict of variables to be memorized.

        If a dict is specified, the keys of the dict would be used as the
        serialization keys via `VariableSaver`.

    initial_metric : float | tf.Tensor | tf.Variable
        The initial best metric (usually for recovering from a previous
        session).

    save_dir : str
        The directory where to save the variable values.
        If not specified, will use a temporary directory.

    smaller_is_better : bool
        Whether or not smaller metric values are better? (default True)

    restore_on_error : bool
        Whether or not to restore the memorized parameters even on error?
        (default False)

    cleanup : bool
        Whether or not to cleanup the saving directory on exit?

        This argument will be ignored if `save_dir` is None, in which case
        the temporary directory will always be deleted on exit.

    name : str
        Optional name of this scope.

    Yields
    ------
    _EarlyStopping
        The object that receives metric values during the early-stopping
        context.
    """
    if not param_vars:
        raise ValueError('`param_vars` must not be empty.')

    if save_dir is None:
        with TemporaryDirectory() as tempdir:
            with early_stopping(param_vars,
                                initial_metric=initial_metric,
                                save_dir=tempdir,
                                cleanup=False,
                                smaller_is_better=smaller_is_better,
                                restore_on_error=restore_on_error,
                                name=name) as es:
                yield es

    else:
        if isinstance(initial_metric, (tf.Tensor, tf.Variable)):
            initial_metric = initial_metric.eval()

        with tf.name_scope(name):
            saver = VariableSaver(param_vars, save_dir)
            save_dir = os.path.abspath(save_dir)
            makedirs(save_dir, exist_ok=True)

            es = _EarlyStopping(saver,
                                best_metric=initial_metric,
                                smaller_is_better=smaller_is_better)
            try:
                yield es
            except Exception as ex:
                if isinstance(ex, KeyboardInterrupt) or restore_on_error:
                    saver.restore()
                raise
            else:
                saver.restore()
            finally:
                if cleanup:
                    try:
                        if os.path.exists(save_dir):
                            shutil.rmtree(save_dir)
                    except Exception:
                        getLogger(__name__).error(
                            'Failed to cleanup validation save dir %r.',
                            save_dir, exc_info=True
                        )
                if not es.ever_updated:
                    warnings.warn(
                        'Early-stopping metric has never been updated. '
                        'The variables will keep their latest values. '
                        'Did you forget to add corresponding metric?'
                    )
def test_save_restore(self):
    class MyObject(CheckpointSavableObject):
        def __init__(self, value):
            self.value = value

        def get_state(self):
            return self.__dict__

        def set_state(self, state):
            keys = list(self.__dict__)
            for k in keys:
                if k not in state:
                    self.__dict__.pop(k)
            for k in state:
                self.__dict__[k] = state[k]

    with TemporaryDirectory() as tmpdir, \
            self.test_session() as sess:
        save_dir = os.path.join(tmpdir, 'saves')
        v = tf.get_variable('v', dtype=tf.int32, initializer=12)
        sv = ScheduledVariable('sv', dtype=tf.float32, initial_value=34)
        obj = MyObject(56)
        obj2 = MyObject(90)
        ensure_variables_initialized()

        # test constructing a saver upon an empty directory
        saver = CheckpointSaver([v, sv], save_dir,
                                objects={'obj': obj, 'obj2': obj2})
        self.assertIsNone(saver.latest_checkpoint())
        with pytest.raises(IOError, match='No checkpoint file is found'):
            saver.restore_latest()
        saver.restore_latest(ignore_non_exist=True, session=sess)

        # save the first checkpoint
        ckpt_0 = saver.save(0)
        self.assertEqual(saver.latest_checkpoint(), ckpt_0)

        # now we change the states
        sess.run(tf.assign(v, 1212))
        sv.set(3434)
        obj.value = 5656
        obj.value2 = 7878
        obj2.value = 9090
        ckpt_1 = saver.save(1, session=sess)
        self.assertEqual(saver.latest_checkpoint(), ckpt_1)

        # construct a saver on an existing checkpoint directory
        saver = CheckpointSaver([v, sv], save_dir,
                                objects={'obj': obj, 'obj2': obj2})
        self.assertEqual(saver.latest_checkpoint(), ckpt_1)

        # restore the latest checkpoint
        saver.restore_latest()
        self.assertListEqual(sess.run([v, sv]), [1212, 3434])
        self.assertEqual(obj.value, 5656)
        self.assertEqual(obj.value2, 7878)
        self.assertEqual(obj2.value, 9090)

        # restore a previous checkpoint
        saver.restore(ckpt_0, sess)
        self.assertListEqual(sess.run([v, sv]), [12, 34])
        self.assertEqual(obj.value, 56)
        self.assertFalse(hasattr(obj, 'value2'))
        self.assertEqual(obj2.value, 90)

        # try to restore only a part of the variables and objects
        saver = CheckpointSaver([v], save_dir, objects={'obj': obj})
        saver.restore_latest()
        self.assertListEqual(sess.run([v, sv]), [1212, 34])
        self.assertEqual(obj.value, 5656)
        self.assertEqual(obj.value2, 7878)
        self.assertEqual(obj2.value, 90)

        # try to restore a non-existent object
        saver = CheckpointSaver([v], save_dir, objects={'obj3': obj})
        with pytest.raises(KeyError,
                           match='Object `obj3` not found in the '
                                 'checkpoint'):
            saver.restore_latest()
def test_constructor(self):
    with TemporaryDirectory() as tmpdir:
        v1 = tf.get_variable('v1', dtype=tf.int32, shape=())
        with tf.variable_scope('parent'):
            v2 = tf.get_variable('v2', dtype=tf.int32, shape=())
        sv = ScheduledVariable('sv', dtype=tf.float32, initial_value=123)
        obj = Mock(
            spec=CheckpointSavableObject,
            get_state=Mock(return_value={'value': 123}),
            set_state=Mock()
        )
        obj2 = Mock(
            spec=CheckpointSavableObject,
            get_state=Mock(return_value={'value': 456}),
            set_state=Mock()
        )

        saver = CheckpointSaver([v1, sv, v2], tmpdir + '/1',
                                objects={'obj': obj, 'obj2': obj2})
        self.assertEqual(saver.save_dir, tmpdir + '/1')
        self.assertIsNone(saver._saver._max_to_keep)
        self.assertTrue(saver.save_meta)
        self.assertDictEqual(
            saver._var_dict,
            {'v1': v1, 'parent/v2': v2, 'sv': sv.variable,
             CHECKPOINT_VAR_NAME: saver._serial_var.variable}
        )
        self.assertIsInstance(saver.saver, tf.train.Saver)

        saver = CheckpointSaver(
            {'vv1': v1, 'v22': v2, 'svv': sv}, tmpdir + '/2',
            objects={'oobj': obj, 'obj2': obj2},
            filename='variables.dat', max_to_keep=3, save_meta=False
        )
        self.assertEqual(saver.save_dir, tmpdir + '/2')
        self.assertEqual(saver._saver._max_to_keep, 3)
        self.assertFalse(saver.save_meta)
        self.assertDictEqual(
            saver._var_dict,
            {'vv1': v1, 'v22': v2, 'svv': sv.variable,
             CHECKPOINT_VAR_NAME: saver._serial_var.variable}
        )
        self.assertIsInstance(saver.saver, tf.train.Saver)

        with pytest.raises(TypeError, match='Not a variable'):
            _ = CheckpointSaver([object()], tmpdir)

        with pytest.raises(TypeError, match='Not a variable'):
            _ = CheckpointSaver([tf.constant(123.)], tmpdir)

        with pytest.raises(TypeError, match='Not a savable object'):
            _ = CheckpointSaver([], tmpdir, {'obj': object()})

        with pytest.raises(TypeError, match='Not a savable object'):
            _ = CheckpointSaver([], tmpdir, {'obj': tf.constant(0.)})

        with pytest.raises(KeyError,
                           match='Name is reserved for `variables`'):
            _ = CheckpointSaver(
                [tf.get_variable(CHECKPOINT_VAR_NAME, dtype=tf.int32,
                                 initializer=0)],
                tmpdir
            )

        with pytest.raises(KeyError,
                           match='Name is reserved for `variables`'):
            _ = CheckpointSaver(
                {CHECKPOINT_VAR_NAME: tf.get_variable(
                    'a', dtype=tf.int32, initializer=0)},
                tmpdir
            )

        with pytest.raises(KeyError,
                           match='Name is reserved for `objects`'):
            _ = CheckpointSaver([], tmpdir, {CHECKPOINT_VAR_NAME: obj})
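

# A sketch of how the CheckpointSaver API exercised above is typically used in
# a training script (illustrative only; ``train_state`` is assumed to be a
# CheckpointSavableObject and ``train_steps`` a placeholder generator, and an
# initialized default session is assumed): restore the latest checkpoint if
# one exists, then save periodically.
def _example_checkpoint_saver(save_dir, train_state,
                              train_steps):  # pragma: no cover
    saver = CheckpointSaver(tf.global_variables(), save_dir,
                            objects={'train_state': train_state},
                            max_to_keep=3)
    # resume if a previous run left a checkpoint behind
    saver.restore_latest(ignore_non_exist=True)
    for step in train_steps():
        # ... run one training step ...
        if step % 1000 == 0:
            saver.save(step)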
def test_errors(self): with TemporaryDirectory() as tempdir: with pytest.raises(ValueError, match='`checkpoint_epoch_freq` must ' 'be a positive integer'): with TrainLoop([], checkpoint_dir=tempdir, checkpoint_epoch_freq=0): pass with pytest.raises(ValueError, match='Currently `early_stopping = True` is not ' 'supported when a file path is ' 'specified for `restore_checkpoint`'): with TrainLoop([], checkpoint_dir=tempdir, early_stopping=True, restore_checkpoint=os.path.join( tempdir, 'checkpoint.dat')): pass with pytest.raises(RuntimeError, match='Checkpoint directory is ' 'not configured'): with TrainLoop([]) as loop: loop.make_checkpoint() obj = Mock( spec=CheckpointSavableObject, get_state=Mock(return_value={}), set_state=Mock() ) with pytest.raises(KeyError, match='Name is reserved for ' '`checkpoint_save_objects`'): with TrainLoop([], checkpoint_dir=tempdir, checkpoint_save_objects={ TRAIN_LOOP_STATES_CKPT_NAME: obj }): pass with pytest.raises(KeyError, match='Name is reserved for ' '`checkpoint_save_objects`'): with TrainLoop([], checkpoint_dir=tempdir, checkpoint_save_objects={ EARLY_STOPPING_STATES_CKPT_NAME: obj }): pass with pytest.raises( RuntimeError, match='Another epoch loop has been opened'): with TrainLoop([], max_epoch=10) as loop: for _ in loop.iter_epochs(): for _ in loop.iter_epochs(): pass with pytest.raises( RuntimeError, match='Step loop must be opened within active ' 'epoch loop'): with TrainLoop([], max_step=10) as loop: for _ in loop.iter_steps(): pass with pytest.raises( RuntimeError, match='Another step loop has been opened'): with TrainLoop([], max_epoch=10, max_step=10) as loop: for _ in loop.iter_epochs(): for _ in loop.iter_steps(): for _ in loop.iter_steps(): pass def require_context(): return pytest.raises( RuntimeError, match='An epoch or a step loop is expected, ' 'but neither has been opened') with require_context(): with TrainLoop([]) as loop: with loop.timeit('timer'): pass with require_context(): with TrainLoop([]) as loop: with loop.metric_collector('metric'): pass with require_context(): with TrainLoop([]) as loop: loop.collect_metrics(loss=1.) with require_context(): with TrainLoop([]) as loop: loop.println('', with_tag=True) with require_context(): with TrainLoop([]) as loop: loop.print_logs() with pytest.raises( RuntimeError, match='`data_generator` is required when ' '`max_step` is not configured, so as to ' 'prevent an unstoppable step loop'): with TrainLoop([], max_epoch=10) as loop: for _ in loop.iter_epochs(): for _ in loop.iter_steps(): pass with pytest.raises( TypeError, match='`metrics` should be a dict'): with TrainLoop([], max_epoch=10) as loop: for _ in loop.iter_epochs(): loop.collect_metrics(())
def test_checkpoint(self):
    class MyObject(CheckpointSavableObject):
        def __init__(self):
            self.value = 123

        def get_state(self):
            return {'value': self.value}

        def set_state(self, state):
            self.value = state['value']

    o = MyObject()
    var = ScheduledVariable('var', initial_value=456, dtype=tf.int32)

    with self.test_session() as sess, \
            TemporaryDirectory() as tempdir:
        ensure_variables_initialized()

        with TrainLoop([var.variable], checkpoint_dir=tempdir,
                       checkpoint_save_objects={'o': o}) as loop:
            loop.make_checkpoint()

        # test restore_checkpoint == True
        o.value = 1234
        var.set(4567)
        self.assertEqual(o.value, 1234)
        self.assertEqual(var.get(), 4567)

        with TrainLoop([var.variable], checkpoint_dir=tempdir,
                       checkpoint_save_objects={'o': o}):
            self.assertEqual(loop.epoch, 0)
            self.assertEqual(loop.step, 0)
            self.assertEqual(o.value, 123)
            self.assertEqual(var.get(), 456)

        # test restore_checkpoint == False, and generate new checkpoints
        o.value = 1234
        var.set(4567)

        with TrainLoop([var.variable], checkpoint_dir=tempdir,
                       checkpoint_save_objects={'o': o},
                       checkpoint_epoch_freq=2,
                       restore_checkpoint=False,
                       max_epoch=8) as loop:
            self.assertEqual(loop.epoch, 0)
            self.assertEqual(loop.step, 0)
            self.assertEqual(o.value, 1234)
            self.assertEqual(var.get(), 4567)

            for epoch in loop.iter_epochs():
                for _ in loop.iter_steps([1, 1]):
                    pass
                o.value = 9120 + epoch
                var.set(9450 + epoch)

        # restore from latest
        with TrainLoop([var.variable], checkpoint_dir=tempdir,
                       checkpoint_save_objects={'o': o}) as loop:
            self.assertEqual(loop.epoch, 8)
            self.assertEqual(loop.step, 16)
            self.assertEqual(o.value, 9128)
            self.assertEqual(var.get(), 9458)

        # restore from specified file
        for epoch in [2, 4, 6, 8]:
            restore_checkpoint = os.path.join(
                tempdir, 'checkpoint/checkpoint.dat-{}'.format(epoch * 2))
            with TrainLoop([var.variable], checkpoint_dir=tempdir,
                           checkpoint_save_objects={'o': o},
                           restore_checkpoint=restore_checkpoint) as loop:
                self.assertEqual(loop.epoch, epoch)
                self.assertEqual(loop.step, epoch * 2)
                self.assertEqual(o.value, 9120 + epoch)
                self.assertEqual(var.get(), 9450 + epoch)
def test_errors(self):
    with TemporaryDirectory() as tempdir:
        a = tf.get_variable('a', initializer=1, dtype=tf.int32)
        with pytest.raises(ValueError,
                           match='At least 2 versions should be kept'):
            _ = VariableSaver([a], tempdir, max_versions=1)
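

# VariableSaver (whose max_versions validation is tested above) is the
# lower-level helper used by EarlyStopping and early_stopping(): it snapshots
# a set of variables into a directory and restores them later.  The sketch
# below is illustrative only, based on the save()/restore() calls seen in
# those utilities.
def _example_variable_saver(save_dir):  # pragma: no cover
    w = tf.get_variable('w', initializer=0.5, dtype=tf.float32)
    saver = VariableSaver([w], save_dir, max_versions=2)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save()                           # snapshot the current value
        sess.run(tf.assign(w, 123.0))          # mutate the variable
        saver.restore(ignore_non_exist=True)   # roll back to the snapshot
        print(sess.run(w))                     # -> 0.5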