Example #1
    def after_run(self, run_context, run_values):
        """
        Overridden from :obj:`tf.train.SessionRunHook`; see its documentation
        for more information.
        """
        execution_time = time.time() - self._time_run_start
        data = run_values.results['data']
        global_step = run_context.session.run(self._global_step_tensor)
        mode = self._callback_or_handler.mode
        number_iterations_per_epoch = (
            self._callback_or_handler.number_iterations_per_epoch)
        iteration_number = (
            self._callback_or_handler.iteration_info.iteration_number)

        epoch_number, iteration_number, summary_step = (
            model_utils.get_iteration_stat_from_global_step(
                mode=mode,
                global_step=global_step,
                previous_iteration_number=iteration_number,
                number_iterations_per_epoch=number_iterations_per_epoch,
                max_number_of_iterations_per_epoch=(
                    self._max_number_of_iterations_per_epoch)))
        is_last_iteration = not iteration_number % number_iterations_per_epoch
        iteration_info = RunIterationInfo(epoch_number=epoch_number,
                                          iteration_number=iteration_number,
                                          execution_time=execution_time,
                                          is_last_iteration=is_last_iteration,
                                          session_run_context=run_context)
        self._callback_or_handler.iteration_info = iteration_info
        self._callback_or_handler.summary_step = summary_step
        self._callback_or_handler(**data)
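
The helper `model_utils.get_iteration_stat_from_global_step` is not shown here. As a rough illustration of the arithmetic it performs, here is a hypothetical sketch that recovers 1-based epoch and in-epoch iteration numbers from a global step (the real helper also accounts for the mode, the previous iteration number, and a per-epoch cap):

    def split_global_step(global_step, number_iterations_per_epoch):
        # Hypothetical decomposition: global_step counts iterations
        # across all epochs, starting at 1.
        epoch_number = (global_step - 1) // number_iterations_per_epoch + 1
        iteration_number = (global_step - 1) % number_iterations_per_epoch + 1
        return epoch_number, iteration_number

    assert split_global_step(256, 256) == (1, 256)  # last iteration of epoch 1
    assert split_global_step(257, 256) == (2, 1)

With this numbering, `iteration_number % number_iterations_per_epoch == 0` holds exactly on the last iteration of an epoch, which is what the `is_last_iteration` check above relies on.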
Example #2
    def test_set_save_name(self, provide_save_name, provide_iteration_info):
        saver = SaverCallback(inbound_nodes=[],
                              save_prefix="prefix",
                              save_suffix="suffix").build()
        saver.log_dir = "save_target"
        save_name = "some_new_name" if provide_save_name else None
        if provide_iteration_info:
            iteration_info = RunIterationInfo(epoch_number=5,
                                              iteration_number=10)
            saver.iteration_info = iteration_info
            saver._sample_index = 3

        saver.set_save_name(save_name)

        if provide_save_name:
            save_name_must = "save_target/prefix-some_new_name-suffix"
        else:
            if provide_iteration_info:
                save_name_must = (
                    "save_target/prefix-"
                    "epoch_{:03d}_iter_{:05d}_sample_{:03d}-suffix"
                    "".format(5, 10, 3))
            else:
                save_name_must = "save_target/prefix-suffix"

        self.assertEqual(save_name_must, saver.save_name)

        saver.log_dir = self.get_temp_dir()
        saver.save_name_depth = -1
        saver.set_save_name("subdir/temp/file-654")
        self.assertTrue(
            os.path.exists(
                os.path.join(self.get_temp_dir(), "prefix-subdir/temp")))
        saver.set_save_name("subdir/temp/file-654")
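
The expected default name in this test is assembled from the log directory, prefix, iteration info, and suffix. A minimal sketch of that formatting, assuming POSIX path separators and the exact template asserted above (the helper name is hypothetical):

    import os

    def format_save_name(log_dir, prefix, suffix,
                         epoch_number, iteration_number, sample_index):
        # Hypothetical reconstruction of the pattern asserted in the test.
        stem = "epoch_{:03d}_iter_{:05d}_sample_{:03d}".format(
            epoch_number, iteration_number, sample_index)
        return os.path.join(log_dir, "-".join([prefix, stem, suffix]))

    assert (format_save_name("save_target", "prefix", "suffix", 5, 10, 3)
            == "save_target/prefix-epoch_005_iter_00010_sample_003-suffix")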
Example #3
    def test_on_iteration_end(self, evaluate):
        buffer_callback = _DummySummarizerCallback(inbound_nodes=[])
        results = []
        for i, each_batch in enumerate(self.data):
            if i == self.num_batches - 1:
                buffer_callback.iteration_info = RunIterationInfo(
                    is_last_iteration=True)
            if evaluate:
                evaluate_batch = self.evaluate[i]
            else:
                evaluate_batch = None
            results.append(buffer_callback.on_iteration_end(
                evaluate=evaluate_batch, **each_batch))

        if not evaluate:
            result_last_must = self._get_result_must_in_interval()
            results_must = [None, None, result_last_must]
        else:
            result_last_list = [self._get_result_must_in_interval(7, 13),
                                self._get_result_must_in_interval(13, None)]
            result_last = {k: np.concatenate(
                [each_res[k] for each_res in result_last_list], 0)
                for k in result_last_list[0]}
            results_must = [None,
                            self._get_result_must_in_interval(0, 7),
                            result_last]

        for each_result, each_result_must in zip(results, results_must):
            if each_result_must is None:
                self.assertIsNone(each_result)
            else:
                self.assertAllClose(each_result,
                                    each_result_must)
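
`_DummySummarizerCallback` is not shown, but the assertions imply a buffer-and-flush contract: intermediate calls return `None` while batches accumulate, and a flush returns the buffered batches concatenated along axis 0. A simplified sketch of that contract (ignoring the evaluate-driven flushes the test also exercises):

    import numpy as np

    class BufferingCallbackSketch:
        # Illustrative only: accumulate batches, return None until the
        # last iteration, then return everything concatenated on axis 0.
        def __init__(self):
            self._buffer = []
            self.iteration_info = None

        def on_iteration_end(self, **batch):
            self._buffer.append(batch)
            info = self.iteration_info
            if info is None or not info.is_last_iteration:
                return None
            merged = {key: np.concatenate([b[key] for b in self._buffer], 0)
                      for key in self._buffer[0]}
            self._buffer = []
            return merged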
Example #4
    def test_iteration_info_setter(self):
        iteration_info = RunIterationInfo(1, 100, 10)
        callbacks, _ = self._get_callbacks_and_incoming_nucleotides()
        callbacks_handler = CallbacksHandler(callbacks=callbacks).build()
        callbacks_handler.iteration_info = iteration_info
        for callback in callbacks:
            self.assertTupleEqual(tuple(iteration_info),
                                  tuple(callback.iteration_info))
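
Across these examples, `RunIterationInfo` is constructed positionally, by keyword, and converted with `tuple(...)`, which is consistent with a namedtuple whose fields all have defaults. A plausible sketch (the real definition may differ):

    from collections import namedtuple

    # Field order inferred from the positional calls in these examples.
    RunIterationInfo = namedtuple(
        "RunIterationInfo",
        ["epoch_number", "iteration_number", "execution_time",
         "is_last_iteration", "session_run_context"])
    RunIterationInfo.__new__.__defaults__ = (0, 0, 0.0, False, None)

    info = RunIterationInfo(1, 100, 10)
    assert tuple(info) == (1, 100, 10, False, None)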
Example #5
    def __init__(self, callbacks: Union[List[CoordinatorCallback],
                                        Dict[str, CoordinatorCallback],
                                        CoordinatorCallback]):
        self.callbacks = None  # type: Dict[str, CoordinatorCallback]
        super().__init__(callbacks=callbacks)
        self._iteration_info = RunIterationInfo(0, 0, 0.0)
        self._number_iterations_per_epoch = None  # type: Optional[int]
        self._log_dir = None  # type: Optional[str]
        self._summary_writer = None  # type: Optional[tf.summary.FileWriter]
        self._summary_step = None  # type: Optional[int]
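
The propagation asserted in `test_iteration_info_setter` above suggests that `CallbacksHandler` forwards `iteration_info` to every child callback through a property setter. An illustrative sketch of that pattern, assuming the callbacks end up stored as the dict annotated in the constructor:

    @property
    def iteration_info(self):
        return self._iteration_info

    @iteration_info.setter
    def iteration_info(self, iteration_info):
        # Illustrative: forward the new info to every child callback so
        # tuple(callback.iteration_info) matches the handler's value.
        self._iteration_info = iteration_info
        for callback in self.callbacks.values():
            callback.iteration_info = iteration_info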
Example #6
    def _get_inputs(self, epoch_number, monitor_mode, sess):
        run_context = tf.train.SessionRunContext([], session=sess)
        run_context.request_stop = MagicMock(wraps=run_context.request_stop)
        iteration_info = RunIterationInfo(epoch_number, epoch_number * 100, 0,
                                          True, run_context)
        monitor_data = (self.monitor_data_max
                        if monitor_mode == "max" else self.monitor_data_min)
        inputs = {"monitor": monitor_data[epoch_number - 1]}
        should_stop = False
        if epoch_number == self.last_epoch:
            should_stop = True
        return inputs, iteration_info, should_stop
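
The wrapped `request_stop` points at an early-stopping callback: once the monitored value stops improving, it asks the session to stop via the `session_run_context` carried in `iteration_info`. A minimal illustrative sketch for `min` mode (the class name and patience policy are assumptions, not the project's actual implementation):

    class EarlyStoppingSketch:
        # Illustrative only: request a stop once `monitor` has not
        # improved for `patience` consecutive calls (min mode).
        def __init__(self, patience=2):
            self._patience = patience
            self._best = float("inf")
            self._bad_steps = 0
            self.iteration_info = None

        def on_iteration_end(self, monitor):
            if monitor < self._best:
                self._best = monitor
                self._bad_steps = 0
            else:
                self._bad_steps += 1
            info = self.iteration_info
            should_stop = (self._bad_steps >= self._patience
                           and info is not None
                           and info.session_run_context is not None)
            if should_stop:
                info.session_run_context.request_stop()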
Example #7
    def __init__(self,
                 *,
                 inbound_nodes,
                 name=None,
                 incoming_keys_mapping=None):
        super(CoordinatorCallback,
              self).__init__(inbound_nodes=inbound_nodes,
                             name=name,
                             incoming_keys_mapping=incoming_keys_mapping)
        self._iteration_info = RunIterationInfo(0, 0, 0.0)
        self._number_iterations_per_epoch = None  # type: Optional[int]
        self._log_dir = None  # type: Optional[str]
        self._summary_writer = None  # type: Optional[tf.summary.FileWriter]
        self._summary_step = None  # type: Optional[int]
Example #8
    def test_on_iteration_end(self, evaluator_is_last_iteration,
                              evaluator_call, with_summary_writer,
                              epoch_number, iteration_number,
                              is_last_iteration):
        def _evaluator_call(**inputs):
            return {"kpi": 1}

        log_dir = self.get_temp_dir()

        evaluator_is_last_iteration.side_effect = lambda x: x
        evaluator_call.side_effect = _evaluator_call
        accumulator = KPIAccumulator()
        plugin = KPIPlugin()

        evaluator = KPIEvaluator(plugin, accumulator).build()
        evaluator.clear_state = MagicMock(return_value=None)

        callback = convert_evaluator_to_callback(evaluator, 3)

        callback.iteration_info = RunIterationInfo(epoch_number,
                                                   iteration_number, 0,
                                                   is_last_iteration)

        if with_summary_writer:
            summary_writer = tf.summary.FileWriter(log_dir)
            callback.summary_writer = summary_writer
            callback.summary_step = 7

        callback.on_iteration_end(**self.data_batch)

        if epoch_number == 0 or epoch_number % 3 or iteration_number > 1:
            evaluator.clear_state.assert_not_called()
        else:
            evaluator.clear_state.assert_called_once_with()

        evaluator_call.assert_called_once_with(**self.data_batch)

        self.assertEmpty(plugin.savers)
        evaluator_is_last_iteration.assert_called_once_with(is_last_iteration)
        if not with_summary_writer:
            self.assertEmpty(accumulator.savers)
        else:
            self.assertEqual(1, len(accumulator.savers))
            summary_saver = accumulator.savers[0]
            self.assertIsInstance(summary_saver, TfSummaryKPISaver)
            self.assertIs(callback.summary_writer,
                          summary_saver.summary_writer)
            self.assertEqual(7, summary_saver.summary_step)
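
The branch on `epoch_number` above encodes when the evaluator's state is expected to be cleared. Restated as a standalone predicate (a hypothetical helper; the `3` is presumably the interval passed to `convert_evaluator_to_callback`, judging by the modulo check):

    def should_clear_state(epoch_number, iteration_number, clear_every=3):
        # State is cleared only on the first iteration of every
        # `clear_every`-th epoch, and never for epoch 0.
        return (epoch_number != 0
                and epoch_number % clear_every == 0
                and iteration_number <= 1)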
Example #9
    def test_on_iteration_end(self, mock_stdout):
        callback = BaseLogger(inbound_nodes=['node1'])
        mode = 'train'
        epoch = 10
        iter_n = 112
        time_exec = 10.1
        iter_per_epoch = 256
        data = {
            'loss': {
                'total_loss': 10,
                'loss_2': 20
            },
            'metric': {
                'metric_1': np.array([10]),
                'metric_2': np.array([10, 10])
            }
        }

        iteration_info = RunIterationInfo(epoch_number=epoch,
                                          iteration_number=iter_n,
                                          execution_time=time_exec)
        callback.iteration_info = iteration_info
        callback.mode = mode
        callback.number_iterations_per_epoch = iter_per_epoch
        callback.on_iteration_end(**data)
        time_remain = time_exec / iter_n * iter_per_epoch - time_exec
        printed = mock_stdout.getvalue()
        printed_lines = printed.split('\n')[:-1]
        self.assertEqual(len(printed_lines), 2)
        self.assertEqual(len(printed_lines[0]), len(printed_lines[1]))

        printed_names = printed_lines[0].strip().split('|')[1:-1]
        printed_values = printed_lines[1].strip().split('|')[1:-1]
        self.assertListEqual([len(n) for n in printed_names],
                             [len(v) for v in printed_values])
        printed_names_str = list(map(str.strip, printed_names))
        printed_values_str = list(map(str.strip, printed_values))
        names_after_stat = printed_names_str[5:]

        self.assertListEqual(names_after_stat,
                             ['loss_2', 'total_loss', 'metric_1'])
        printed_dict = dict(zip(printed_names_str, printed_values_str))
        printed_dict = {
            k: float(v) if k not in ['mode', 'iter'] else v
            for k, v in printed_dict.items()
        }
        not_printable_key = 'metric_2'
        self.assertNotIn(not_printable_key, printed_dict)
        printed_must = {
            'mode': mode,
            'epoch': epoch,
            'time_exec, [s]': round(time_exec, 2),
            'time_remain, [s]': round(time_remain, 2),
            'iter': '{}/{}'.format(iter_n, iter_per_epoch),
            'total_loss': 10,
            'loss_2': 20,
            'metric_1': np.array([10])
        }
        self.assertDictEqual(printed_dict, printed_must)

        mock_stdout.truncate(0)
        callback.on_iteration_end(**data)
        printed = mock_stdout.getvalue()
        printed_lines = printed.split('\n')[:-1]
        self.assertEqual(len(printed_lines), 1)
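
The expected `time_remain` is a linear extrapolation: if `iter_n` iterations took `time_exec` seconds, a full epoch of `iter_per_epoch` iterations is estimated at `time_exec / iter_n * iter_per_epoch`, so the remainder is that estimate minus the time already spent:

    def estimate_time_remaining(time_exec, iter_n, iter_per_epoch):
        # Linear extrapolation from the iterations completed so far.
        return time_exec / iter_n * iter_per_epoch - time_exec

    # 10.1 s for 112 of 256 iterations -> about 12.99 s remaining
    assert round(estimate_time_remaining(10.1, 112, 256), 2) == 12.99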