Ejemplo n.º 1
0
    def test_tf_traceability(self):
        """Traceability on a TF model should produce a .tex report plus a resources/ folder of artifacts.

        The top-level directory and the ``resources`` directory are listed separately. The earlier version took the
        second tuple yielded by ``os.walk`` and assumed it described ``resources``. ``os.walk`` visits sibling
        subdirectories in filesystem order, so that only held while ``resources`` happened to be visited first.
        """
        # Start from a clean output directory so stale artifacts cannot satisfy the assertions.
        if os.path.isdir(self.tf_dir):
            shutil.rmtree(self.tf_dir)

        trace = Traceability(save_path=self.tf_dir)
        est = _build_estimator(
            fe.build(model_fn=LeNet, optimizer_fn="adam",
                     model_name='tfLeNet'), trace)

        trace.system = est.system
        trace.system.epoch_idx = 1
        trace.system.summary.name = "TF Test"

        trace.on_begin(Data())
        trace.on_end(Data())

        # The first os.walk tuple is (dirpath, dirnames, filenames) for tf_dir itself.
        _, root_dirs, root_files = next(os.walk(self.tf_dir))
        self.assertIn('resources', root_dirs,
                      "A resources subdirectory should have been generated")
        self.assertIn('tf_test.tex', root_files,
                      "The tex file should have been generated")
        # Might be a pdf and/or a .ds_store file depending on system, but shouldn't be more than that
        self.assertLessEqual(len(root_files), 3,
                             "Extra files should not have been generated")
        # List resources/ explicitly rather than relying on os.walk's sibling-directory order.
        _, _, resource_files = next(os.walk(os.path.join(self.tf_dir, 'resources')))
        self.assertIn('tf_test_tfLeNet.pdf', resource_files,
                      "A figure for the model should have been generated")
        self.assertIn('tf_test_logs.png', resource_files,
                      "A log image should have been generated")
        self.assertIn('tf_test.txt', resource_files,
                      "A raw log file should have been generated")
Ejemplo n.º 2
0
    def test_max_to_keep_tf_architecture(self):
        """With max_to_keep=2 and save_architecture=True, only the two newest weight/architecture pairs survive."""
        save_dir = tempfile.mkdtemp()
        tf_model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
        saver = ModelSaver(model=tf_model, save_dir=save_dir, max_to_keep=2, save_architecture=True)
        saver.system = sample_system_object()

        def artifact_paths():
            # (weights file, architecture directory) for the saver's current epoch.
            stem = "{}_epoch_{}".format(saver.model.model_name, saver.system.epoch_idx)
            return os.path.join(save_dir, stem + '.h5'), os.path.join(save_dir, stem)

        # Three saves on consecutive epochs; the third is the oldest one expected to remain.
        saver.on_epoch_end(data=Data())
        for _ in range(2):
            saver.system.epoch_idx += 1
            saver.on_epoch_end(data=Data())
        older_weights, older_arch = artifact_paths()

        saver.system.epoch_idx += 1
        saver.on_epoch_end(data=Data())
        newer_weights, newer_arch = artifact_paths()

        with self.subTest('Check only four files are kept'):
            self.assertEqual(len(os.listdir(save_dir)), 4)

        with self.subTest('Check two latest models are kept'):
            self.assertTrue(os.path.exists(older_weights))
            self.assertTrue(os.path.exists(newer_weights))
            self.assertTrue(os.path.exists(older_arch))
            self.assertTrue(os.path.isdir(older_arch))
            self.assertTrue(os.path.exists(newer_arch))
            self.assertTrue(os.path.isdir(newer_arch))
Ejemplo n.º 3
0
 def test_on_batch_end(self):
     """on_batch_end should append each batch's labels and predictions to the calibrator's buffers."""
     calibrator = self.pbm_calibrator
     calibrator.y_true = []
     calibrator.y_pred = []

     first = {'y': np.array([0, 0, 1, 1]),
              'y_pred': np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])}
     calibrator.on_batch_end(data=Data(first))
     with self.subTest('Check true values'):
         self.assertTrue(is_equal(calibrator.y_true, list(first['y'])))
     with self.subTest('Check pred values'):
         self.assertTrue(is_equal(calibrator.y_pred, list(first['y_pred'])))

     second = {'y': np.array([1, 1, 0, 0]),
               'y_pred': np.array([[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0]])}
     calibrator.on_batch_end(data=Data(second))
     # After the second batch the buffers should hold both batches, in arrival order.
     with self.subTest('Check true values (2 batches)'):
         self.assertTrue(is_equal(calibrator.y_true, list(first['y']) + list(second['y'])))
     with self.subTest('Check pred values (2 batches)'):
         self.assertTrue(is_equal(calibrator.y_pred, list(first['y_pred']) + list(second['y_pred'])))
Ejemplo n.º 4
0
 def setUpClass(cls):
     """Build sample TF and torch batches plus the expected TensorBoard log paths and start-up message."""
     cls.tf_data = Data(dict(
         x=tf.random.normal(shape=(1, 28, 28, 3)),
         y=tf.random.uniform(shape=(1, ), maxval=10, dtype=tf.int32),
         images=tf.random.normal(shape=(1, 28, 28, 3)),
         embed=np.ones(shape=(1, 3, 3, 3)),
         embed_images=np.ones(shape=(1, 3, 3, 3)),
     ))
     cls.torch_data = Data(dict(
         x=torch.rand(size=(1, 1, 28, 28)),
         y=torch.rand(size=(3, )),
         images=torch.rand(size=(1, 3, 28, 28)),
         embed=np.ones(shape=(1, 3, 3, 3)),
         embed_images=np.ones(shape=(1, 3, 3, 3)),
     ))
     # Expected on-disk layout beneath the temp directory.
     cls.log_dir = os.path.join(tempfile.gettempdir(), 'tensorboard')
     cls.train_path = os.path.join(cls.log_dir, 'train')
     cls.embed_path = os.path.join(cls.train_path, '00001', 'embed')
     cls.on_begin_msg = "FastEstimator-Tensorboard: writing logs to {}".format(cls.log_dir)
Ejemplo n.º 5
0
 def test_on_epoch_end(self):
     """on_epoch_end should write the mean of the recorded test losses (10, 20 -> 15.0) into data."""
     essential = TestEssential(monitor_names={'loss'})
     essential.system = sample_system_object()
     essential.test_results['loss'][''].extend([10, 20])
     epoch_data = Data({})
     essential.on_epoch_end(data=epoch_data)
     self.assertEqual(epoch_data['loss'], 15.0)
Ejemplo n.º 6
0
 def test_on_epoch_end(self):
     """on_epoch_end should write the mean of the recorded eval losses (10, 20 -> 15.0) into data."""
     essential = EvalEssential(monitor_names='loss')
     essential.system = sample_system_object()
     essential.eval_results = {'loss': [10, 20]}
     epoch_data = Data({})
     essential.on_epoch_end(data=epoch_data)
     self.assertEqual(epoch_data['loss'], 15.0)
Ejemplo n.º 7
0
 def setUpClass(cls):
     """Create a TrainEssential whose epoch/step timers started 500 s and 300 s ago, at global step 10."""
     cls.data = Data({'loss': 10})
     essential = TrainEssential(monitor_names='loss')
     essential.system = sample_system_object()
     essential.system.log_steps = 5
     essential.system.global_step = 10
     # Back-date the timers so elapsed-time values are known, non-zero amounts.
     essential.epoch_start = time.perf_counter() - 500
     essential.step_start = time.perf_counter() - 300
     cls.train_essential = essential
Ejemplo n.º 8
0
    def _run_traces_on_end(traces: Iterable[Trace]) -> None:
        """Call ``on_end`` on every trace in order, all sharing one freshly created Data object.

        Args:
            traces: The traces to notify.
        """
        shared_data = Data()
        for current_trace in traces:
            current_trace.on_end(shared_data)
Ejemplo n.º 9
0
    def setUpClass(cls):
        """Build NaN-loss Data for numpy/TF/torch, plus a small TF network and two metric traces.

        Uses ``np.nan`` rather than the ``np.NaN`` alias. The alias was removed in NumPy 2.0 (it raises
        AttributeError there). On older NumPy both names refer to the same float NaN.
        """
        cls.data_np = Data({'loss': np.nan})
        cls.data_tf = Data({'loss': tf.constant(np.nan)})
        cls.data_torch = Data({'loss': torch.tensor(np.nan)})
        cls.expected_msg = "FastEstimator-TerminateOnNaN: NaN Detected in: loss"
        cls.expected_loss_keys = {"ce"}
        cls.expected_all_keys = {"ce", "accuracy", "f1_score"}

        tf_model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
        # NOTE(review): ModelOp writes to "y" while CrossEntropy/metrics read "y_pred"; presumably harmless because
        # these tests only inspect key names, not a real forward pass — confirm before reusing this network.
        cls.network = fe.Network(ops=[
            ModelOp(model=tf_model, inputs="x", outputs="y"),
            CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
            UpdateOp(model=tf_model, loss_name="ce")
        ])
        cls.traces = [
            Accuracy(true_key="y", pred_key="y_pred", output_name="accuracy"),
            F1Score(true_key="y", pred_key="y_pred", output_name="f1_score")
        ]
Ejemplo n.º 10
0
 def setUpClass(cls):
     """Provide an empty Data object and the exact log lines the tests expect to be printed."""
     cls.data = Data({})
     # Start / finish messages.
     cls.on_begin_msg = "FastEstimator-Start: step: 1;"
     cls.on_begin_global_step_msg = "FastEstimator-Start: step: 2;"
     cls.on_end_msg = "FastEstimator-Finish: step: 2;"
     # Per-batch and per-epoch messages for each mode.
     cls.on_batch_end_msg = "FastEstimator-Train: step: 1;"
     cls.on_epoch_end_train_msg = "FastEstimator-Train: step: 2; epoch: 0;"
     cls.on_epoch_end_eval_msg = "FastEstimator-Eval: step: 2; epoch: 0;"
     cls.on_epoch_end_test_msg = "FastEstimator-Test: step: 2; epoch: 0;"
Ejemplo n.º 11
0
 def setUpClass(cls):
     """Build one TF and one torch model, a training-state dict, and small input/target tensors for each."""
     cls.save_dir = tempfile.gettempdir()
     cls.tf_model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam', model_name='tf')
     cls.torch_model = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam', model_name='torch')
     cls.data = Data({'loss': 0.5})
     cls.state = dict(mode='train', epoch=1, warmup=False, deferred={}, scaler=None, tape=None)
     # TensorFlow fixtures: a 2x3 float variable and 2x1 integer targets.
     cls.tf_input_data = tf.Variable([[2.0, 1.5, 1.0], [1.0, -1.0, -0.5]])
     cls.tf_y = tf.constant([[-6], [1]])
     # PyTorch fixtures: a 2x4 float32 tensor and 2x1 float32 targets.
     cls.torch_input_data = torch.tensor([[1.0, 1.0, 1.0, -0.5], [0.5, 1.0, -1.0, -0.5]], dtype=torch.float32)
     cls.torch_y = torch.tensor([[5], [7]], dtype=torch.float32)
Ejemplo n.º 12
0
    def _run_traces_on_ds_begin(self, traces: Iterable[PerDSTrace]) -> None:
        """Call ``on_ds_begin`` on each per-dataset trace, then run ``self._check_early_exit()``.

        Args:
            traces: The PerDSTrace instances to notify, in order; they all share one fresh Data object.
        """
        ds_begin_data = Data()
        for per_ds_trace in traces:
            per_ds_trace.on_ds_begin(ds_begin_data)
        self._check_early_exit()
Ejemplo n.º 13
0
    def _run_traces_on_epoch_end(self, traces: Iterable[Trace], data: "Optional[Data]" = None) -> None:
        """Invoke the on_epoch_end methods of given traces, then run ``self._check_early_exit()``.

        Args:
            traces: List of traces.
            data: The Data object handed to every trace's ``on_epoch_end``. For example, it can carry values
                aggregated during ``on_ds_end``, as ``_run_epoch`` does. When omitted (the previous behavior), a
                fresh empty Data is used.
        """
        if data is None:
            data = Data()
        for trace in traces:
            trace.on_epoch_end(data)
        self._check_early_exit()
Ejemplo n.º 14
0
    def _run_traces_on_batch_begin(self, batch: Dict[str, Any], traces: Iterable[Trace]) -> None:
        """Call ``on_batch_begin`` on every trace with the pipeline batch wrapped in one Data object.

        Args:
            batch: The batch data which was provided by the pipeline.
            traces: The traces to notify, in order; all receive the same Data instance.
        """
        batch_data = Data(batch)
        for active_trace in traces:
            active_trace.on_batch_begin(batch_data)
        self._check_early_exit()
Ejemplo n.º 15
0
 def setUpClass(cls):
     """Create a mid-grey 32x32 RGB image, a red rectangle mask, two labelled boxes, and the expected PNG paths."""
     cls.image_dir = tempfile.gettempdir()
     cls.image_path = os.path.join(cls.image_dir, 'img_train_epoch_0_elem_0.png')
     cls.img_data_path = os.path.join(cls.image_dir, 'img_data_train_epoch_0.png')

     image = 0.5 * np.ones((1, 32, 32, 3))
     mask = np.zeros_like(image)
     # Paint rows 10-19, columns 10-29 pure red in the mask.
     mask[0, 10:20, 10:30, :] = [1, 0, 0]
     boxes = np.array([[[3, 7, 10, 6, 'box1'], [20, 20, 8, 8, 'box2']]])
     cls.input_img = image
     cls.mask = mask

     composite = ImgData(y=np.ones((1, )), x=[image, mask, boxes])
     cls.data = Data({'img': image, 'img_data': composite})
Ejemplo n.º 16
0
    def test_restore(self):
        """A second RestoreWizard on the same directory should restore the saved global_step and epoch_idx."""
        save_path = tempfile.mkdtemp()
        global_step, epoch_idx = 100, 10

        # First wizard: initialise, advance the system counters, then checkpoint at epoch end.
        wizard = RestoreWizard(directory=save_path)
        wizard.system = sample_system_object()
        wizard.on_begin(Data())
        wizard.system.global_step = global_step
        wizard.system.epoch_idx = epoch_idx
        wizard.on_epoch_end(Data())

        # Second wizard on a freshly sampled system: on_begin should load the checkpoint back.
        wizard = RestoreWizard(directory=save_path)
        wizard.system = sample_system_object()
        begin_data = Data()
        wizard.on_begin(begin_data)
        with self.subTest("Check print message"):
            self.assertEqual(begin_data['epoch'], epoch_idx)
        with self.subTest("Check system variables"):
            self.assertEqual(wizard.system.global_step, global_step)
            self.assertEqual(wizard.system.epoch_idx, epoch_idx)
Ejemplo n.º 17
0
    def test_max_to_keep_torch(self):
        """With max_to_keep=2, ModelSaver should keep only the two newest torch checkpoints."""
        save_dir = tempfile.mkdtemp()
        torch_model = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
        saver = ModelSaver(model=torch_model, save_dir=save_dir, max_to_keep=2)
        saver.system = sample_system_object()

        def checkpoint_path():
            # .pt path for the saver's current epoch.
            stem = "{}_epoch_{}".format(saver.model.model_name, saver.system.epoch_idx)
            return os.path.join(save_dir, stem + '.pt')

        saver.on_epoch_end(data=Data())
        saver.system.epoch_idx += 1
        saver.on_epoch_end(data=Data())
        older = checkpoint_path()

        saver.system.epoch_idx += 1
        saver.on_epoch_end(data=Data())
        newer = checkpoint_path()

        with self.subTest('Check only two file are kept'):
            self.assertEqual(len(os.listdir(save_dir)), 2)

        with self.subTest('Check two latest model are kept'):
            self.assertTrue(os.path.exists(older))
            self.assertTrue(os.path.exists(newer))
Ejemplo n.º 18
0
 def test_tf_model(self):
     """ModelSaver should write a TF .h5 checkpoint that reloads to identical trainable weights."""
     model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
     model_saver = ModelSaver(model=model, save_dir=self.save_dir)
     model_saver.system = sample_system_object()
     model_name = "{}_epoch_{}".format(model_saver.model.model_name, model_saver.system.epoch_idx)
     tf_model_path = os.path.join(self.save_dir, model_name + '.h5')
     # self.save_dir may be shared between runs; delete any leftover checkpoint so the existence check
     # below proves that this on_epoch_end call actually wrote the file.
     if os.path.exists(tf_model_path):
         os.remove(tf_model_path)
     model_saver.on_epoch_end(data=Data())
     with self.subTest('Check if model is saved'):
         self.assertTrue(os.path.exists(tf_model_path))
     with self.subTest('Validate model weights'):
         m2 = fe.build(model_fn=one_layer_model_without_weights, optimizer_fn='adam')
         fe.backend.load_model(m2, tf_model_path)
         self.assertTrue(is_equal(m2.trainable_variables, model.trainable_variables))
Ejemplo n.º 19
0
    def _run_traces_on_batch_end(self, batch: Dict[str, Any], prediction: Dict[str, Any],
                                 traces: Iterable[Trace]) -> None:
        """Call ``on_batch_end`` on every trace with a merged view of network predictions and the batch.

        Args:
            batch: The batch data which was provided by the pipeline.
            prediction: The prediction data which was generated by the network.
            traces: The traces to notify, in order; all receive the same Data instance.
        """
        # ChainMap searches ``prediction`` first, so a prediction key shadows a batch key of the same name.
        merged_view = ChainMap(prediction, batch)
        end_data = Data(merged_view)
        for active_trace in traces:
            active_trace.on_batch_end(end_data)
        self._check_early_exit()
Ejemplo n.º 20
0
 def test_save(self):
     """Successive on_epoch_end calls should alternate the save slot A, B, A, B and record it in key.txt."""
     save_path = tempfile.mkdtemp()
     restore_wizard = RestoreWizard(directory=save_path)
     restore_wizard.system = sample_system_object()
     restore_wizard.on_begin(Data())
     key_path = os.path.join(save_path, 'key.txt')
     for round_no, expected_slot in enumerate(['A', 'B', 'A', 'B'], start=1):
         restore_wizard.on_epoch_end(Data())
         with self.subTest("Check Saved Files ({})".format(round_no)):
             self.assertTrue(os.path.exists(key_path))
             self.assertTrue(os.path.exists(os.path.join(save_path, expected_slot)))
         with self.subTest("Check Key is Correct ({})".format(round_no)):
             with open(key_path, 'r') as handle:
                 key = handle.readline()
                 self.assertEqual(key, expected_slot)
Ejemplo n.º 21
0
 def test_torch_model(self):
     """ModelSaver should write a torch .pt checkpoint that reloads to identical parameters."""
     torch_model = fe.build(model_fn=MultiLayerTorchModel, optimizer_fn='adam')
     saver = ModelSaver(model=torch_model, save_dir=self.save_dir)
     saver.system = sample_system_object()
     stem = "{}_epoch_{}".format(saver.model.model_name, saver.system.epoch_idx)
     checkpoint = os.path.join(self.save_dir, stem + '.pt')
     # Remove any stale checkpoint so the existence check proves this call wrote the file.
     if os.path.exists(checkpoint):
         os.remove(checkpoint)
     saver.on_epoch_end(data=Data())
     with self.subTest('Check if model is saved'):
         self.assertTrue(os.path.exists(checkpoint))
     with self.subTest('Validate model weights'):
         reloaded = fe.build(model_fn=MultiLayerTorchModelWithoutWeights, optimizer_fn='adam')
         fe.backend.load_model(reloaded, checkpoint)
         self.assertTrue(is_equal(list(reloaded.parameters()), list(torch_model.parameters())))
Ejemplo n.º 22
0
    def _run_traces_on_end(traces: Iterable[Trace]) -> None:
        """Invoke the on_end methods of given traces, running Traceability traces last.

        Traceability is delayed until every other trace has finished, so it can capture all data, including the
        total training time. Every Traceability instance is deferred and run, in its original relative order.
        Previously only the last one was kept, and any earlier Traceability never had on_end called. All traces
        share one Data object.

        Args:
            traces: List of traces.
        """
        data = Data()
        deferred = []
        for trace in traces:
            if isinstance(trace, Traceability):
                # Delay traceability until the end so that it can capture all data including the total training time
                deferred.append(trace)
                continue
            trace.on_end(data)
        for trace in deferred:
            trace.on_end(data)
Ejemplo n.º 23
0
 def test_on_epoch_end(self):
     """on_epoch_end should emit calibrated predictions and save a function that reproduces them."""
     calibrator = self.pbm_calibrator
     expected = np.array([1.0, 0.0] * 50 + [0.0, 1.0] * 50).reshape(100, 2)
     calibrator.y_true = [0] * 50 + [1] * 50
     # Feed rows of a separate copy so nothing the calibrator does can alias ``expected``.
     calibrator.y_pred = list(expected.copy())
     data = Data()
     calibrator.on_epoch_end(data=data)
     with self.subTest('Check if output exists'):
         self.assertIn('y_pred_calibrated', data)
     with self.subTest('Check save file exists'):
         self.assertTrue(os.path.exists(self.save_file))
     with self.subTest('Check the calibrated values'):
         self.assertTrue(np.allclose(data['y_pred_calibrated'], expected))
     with self.subTest('Check the save file performance'):
         with open(self.save_file, 'rb') as handle:
             calibration_fn = dill.load(handle)
             response = calibration_fn(expected)
             self.assertTrue(np.array_equal(response, data['y_pred_calibrated']))
Ejemplo n.º 24
0
    def _run_traces_on_begin(self, traces: Iterable[Trace]) -> None:
        """Invoke the on_begin methods of given traces.

        All traces share a single Data object. A RestoreWizard is not run in place. It is held back so that it runs
        after the traces that follow it, letting it overwrite whatever those traces set up in on_begin. One
        exception: if a Logger comes after it, the held RestoreWizard is run immediately before that Logger. If no
        Logger follows, it runs after all other traces. Finally ``self._check_early_exit()`` is called.

        NOTE(review): only one pending RestoreWizard is held at a time. If a second RestoreWizard appears before the
        first has been flushed, the first one's on_begin is never called. Presumably at most one RestoreWizard is
        registered; confirm.

        Args:
            traces: List of traces.
        """
        data = Data()
        restore = None  # the pending (deferred) RestoreWizard, if any
        for trace in traces:
            # Delay RestoreWizard until the end so that it can overwrite everyone's on_begin methods
            if isinstance(trace, RestoreWizard):
                restore = trace
                continue
            # Restore does need to run before the logger though
            if isinstance(trace, Logger) and restore:
                restore.on_begin(data)
                restore = None  # already run; do not run it again after the loop
            trace.on_begin(data)
        # No Logger followed the RestoreWizard, so run it now, after every other trace.
        if restore:
            restore.on_begin(data)
        self._check_early_exit()
Ejemplo n.º 25
0
    def run_trace(self) -> None:
        """Drive ``self.trace`` through every lifecycle hook once, in order.

        Each hook's Data object is stored on ``self`` as ``data_<hook name>`` (e.g. ``self.data_on_begin``) so
        tests can inspect what the trace wrote during that hook.
        """
        hook_inputs = (
            ('on_begin', Data),
            ('on_epoch_begin', Data),
            ('on_batch_begin', lambda: Data(self.batch)),
            ('on_batch_end', lambda: Data(ChainMap(self.prediction, self.batch))),
            ('on_epoch_end', Data),
            ('on_end', Data),
        )
        for hook_name, make_data in hook_inputs:
            # Build the Data lazily, immediately before its hook, matching a hand-written sequence.
            hook_data = make_data()
            setattr(self, 'data_' + hook_name, hook_data)
            getattr(self.trace, hook_name)(hook_data)
Ejemplo n.º 26
0
    def run_trace(self) -> None:
        """Attach a sample system to ``self.trace`` and drive it through every lifecycle hook once, in order.

        Each hook's Data object is stored on ``self`` as ``data_<hook name>`` (e.g. ``self.data_on_begin``) so
        tests can inspect what the trace wrote during that hook.
        """
        self.trace.system = sample_system_object()

        hook_inputs = (
            ('on_begin', Data),
            ('on_epoch_begin', Data),
            ('on_batch_begin', lambda: Data(self.batch)),
            ('on_batch_end', lambda: Data(ChainMap(self.prediction, self.batch))),
            ('on_epoch_end', Data),
            ('on_end', Data),
        )
        for hook_name, make_data in hook_inputs:
            # Build the Data lazily, immediately before its hook, matching a hand-written sequence.
            hook_data = make_data()
            setattr(self, 'data_' + hook_name, hook_data)
            getattr(self.trace, hook_name)(hook_data)
Ejemplo n.º 27
0
 def setUpClass(cls):
     """Share one Data object holding ``loss == 10`` across the tests in this class."""
     cls.data = Data(dict(loss=10))
Ejemplo n.º 28
0
 def test_on_batch_end_eval_results_none(self):
     """With no prior eval results, on_batch_end should start a list holding this batch's loss."""
     essential = EvalEssential(monitor_names='loss')
     essential.system = sample_system_object()
     batch_data = Data({'loss': 5})
     essential.on_batch_end(data=batch_data)
     self.assertEqual(essential.eval_results['loss'], [5])
Ejemplo n.º 29
0
 def setUpClass(cls):
     """Share one empty Data object across the tests in this class."""
     cls.data = Data(dict())
Ejemplo n.º 30
0
    def _run_epoch(self, eager: bool) -> None:
        """A method to perform an epoch of activity.

        This method requires that the current mode and epoch already be specified within the self.system object.

        Epoch-level traces get on_epoch_begin first. Then, for each dataset id the pipeline reports for this
        epoch/mode, the method does the following:
            1. Loads the network for that dataset and opens a pipeline loader.
            2. Runs on_ds_begin for the PerDSTrace instances.
            3. Loops batch by batch (on_batch_begin -> network.run_step -> on_batch_end). The loop stops when the
               loader is exhausted or, when the loader is a ``DataLoader``, when batch_idx reaches the configured
               train/eval steps per epoch.
            4. Runs on_ds_end and unloads the network.
        Finally, epoch-level traces get on_epoch_end with the same Data object that was given to every on_ds_end call.

        Args:
            eager: Whether to run the training in eager mode. This is only related to TensorFlow training because
                PyTorch by nature is always in eager mode.
        """
        ds_ids = self.pipeline.get_ds_ids(self.system.epoch_idx,
                                          self.system.mode)
        # Epoch-level traces active for this mode/epoch, ordered by sort_traces (which also receives ds_ids).
        epoch_traces = sort_traces(get_current_items(
            self.traces_in_use,
            run_modes=self.system.mode,
            epoch=self.system.epoch_idx),
                                   ds_ids=ds_ids)
        self._run_traces_on_epoch_begin(traces=epoch_traces)
        self.system.batch_idx = None
        end_epoch_data = Data(
        )  # We will aggregate data over on_ds_end and put it into on_epoch_end for printing
        # run for each dataset
        # NOTE: the loop target is the attribute self.system.ds_id, so the system's current dataset id is updated
        # in place on every iteration.
        for self.system.ds_id in ds_ids:
            ds_traces = get_current_items(self.traces_in_use,
                                          run_modes=self.system.mode,
                                          epoch=self.system.epoch_idx,
                                          ds_id=self.system.ds_id)
            # Union of every key that the active traces declare as an input.
            trace_input_keys = set()
            for ds_trace in ds_traces:
                trace_input_keys.update(ds_trace.inputs)
            network_input_keys = self.network.get_effective_input_keys(
                mode=self.system.mode,
                epoch=self.system.epoch_idx,
                ds_id=self.system.ds_id)
            network_output_keys = self.network.get_all_output_keys(
                mode=self.system.mode,
                epoch=self.system.epoch_idx,
                ds_id=self.system.ds_id)
            # The network is asked to output the keys the traces consume.
            self.network.load_epoch(mode=self.system.mode,
                                    epoch=self.system.epoch_idx,
                                    ds_id=self.system.ds_id,
                                    output_keys=trace_input_keys,
                                    eager=eager)

            # Pipeline output keys: `-` binds tighter than `|`, so this is
            # (trace_input_keys - network_output_keys) | network_input_keys. That is: the trace inputs the network
            # does not produce, plus everything the network consumes.
            with self.pipeline(
                    mode=self.system.mode,
                    epoch=self.system.epoch_idx,
                    ds_id=self.system.ds_id,
                    steps_per_epoch=self.system.steps_per_epoch,
                    output_keys=trace_input_keys - network_output_keys
                    | network_input_keys) as loader:
                loader = self._configure_loader(loader)
                iterator = iter(loader)
                # Fetch the first batch up front so its keys can be used when ordering ds_traces below.
                # NOTE(review): this next() is outside the try/except StopIteration. An empty loader would raise
                # StopIteration out of this method unless Suppressor handles it; confirm loaders are never empty.
                with Suppressor():
                    batch = next(iterator)
                ds_traces = sort_traces(ds_traces,
                                        available_outputs=to_set(batch.keys())
                                        | network_output_keys,
                                        ds_ids=ds_ids)
                # Only PerDSTrace instances receive the per-dataset on_ds_begin / on_ds_end hooks.
                per_ds_traces = [
                    trace for trace in ds_traces
                    if isinstance(trace, PerDSTrace)
                ]
                self._run_traces_on_ds_begin(traces=per_ds_traces)
                while True:
                    try:
                        # The global step is only advanced in train mode; update_batch_idx runs in every mode.
                        if self.system.mode == "train":
                            self.system.update_global_step()
                        self.system.update_batch_idx()
                        batch = self._configure_tensor(loader, batch)
                        self._run_traces_on_batch_begin(batch,
                                                        traces=ds_traces)
                        batch, prediction = self.network.run_step(batch)
                        self._run_traces_on_batch_end(batch,
                                                      prediction,
                                                      traces=ds_traces)
                        # For a DataLoader, stop explicitly once batch_idx hits the train/eval steps-per-epoch
                        # limit (presumably because the loader would keep yielding otherwise; confirm). Test mode
                        # has no such check here.
                        if isinstance(loader, DataLoader) and (
                            (self.system.batch_idx
                             == self.system.train_steps_per_epoch
                             and self.system.mode == "train") or
                            (self.system.batch_idx
                             == self.system.eval_steps_per_epoch
                             and self.system.mode == "eval")):
                            raise StopIteration
                        # Pre-fetch the next batch; exhaustion raises StopIteration and ends this dataset's loop.
                        with Suppressor():
                            batch = next(iterator)
                    except StopIteration:
                        break
                # on_ds_end traces receive the shared end_epoch_data, which is later passed to on_epoch_end.
                self._run_traces_on_ds_end(traces=per_ds_traces,
                                           data=end_epoch_data)
            self.network.unload_epoch()
        self._run_traces_on_epoch_end(traces=epoch_traces, data=end_epoch_data)