def test_train_with_customized_network(self, *args):
    """Test train with a customized (hand-assembled) training network.

    ``optimizer``, ``net_outputs`` and ``loss_fn`` are removed from the run
    context, so the lineage values can presumably only be derived from the
    cells of ``train_network``. The recorded summary lineage must still
    report the loss function, network and optimizer names.
    """
    # NOTE(review): args[0] is presumably the first mock injected by a patch
    # decorator outside this view; confirm which call it stands in for.
    args[0].return_value = 64
    train_callback = TrainLineage(SUMMARY_DIR, True,
                                  self.user_defined_info)

    # Work on a shallow copy: deleting keys from self.run_context directly
    # would leak the removals into every other test sharing that fixture
    # (and a second run of this test would raise KeyError).
    run_context_customized = dict(self.run_context)
    for key in ('optimizer', 'net_outputs', 'loss_fn'):
        del run_context_customized[key]

    # Assemble WithLossCell -> TrainOneStep by hand and pin their child
    # cells explicitly, mimicking a user-defined training wrapper.
    loss_net = WithLossCell(self.net, self.loss_fn)
    loss_net._cells = {'_backbone': self.net, '_loss_fn': self.loss_fn}
    train_net = TrainOneStep(loss_net, self.optimizer)
    train_net._cells = {
        'optimizer': self.optimizer,
        'network': loss_net,
        'backbone': self.net,
    }
    run_context_customized['train_network'] = train_net

    train_callback.begin(RunContext(run_context_customized))
    train_callback.end(RunContext(run_context_customized))

    res = get_summary_lineage(summary_dir=SUMMARY_DIR)
    hyper_parameters = res.get('hyper_parameters', {})
    assert hyper_parameters.get('loss_function') \
        == 'SoftmaxCrossEntropyWithLogits'
    assert res.get('algorithm', {}).get('network') == 'ResNet'
    assert hyper_parameters.get('optimizer') == 'Momentum'
    def test_train_with_customized_network(self, *args):
        """Test train with a customized (hand-assembled) training network.

        ``optimizer``, ``net_outputs`` and ``loss_fn`` are removed from the
        run context, so the lineage values can presumably only be derived
        from the cells of ``train_network``. The first lineage record matched
        by ``self._search_condition`` must still report the loss function,
        network and optimizer names.
        """
        # NOTE(review): args[0] is presumably the first mock injected by a
        # patch decorator outside this view; confirm which call it replaces.
        args[0].return_value = 64
        train_callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)

        # Work on a shallow copy: deleting keys from self.run_context directly
        # would leak the removals into every other test sharing that fixture
        # (and a second run of this test would raise KeyError).
        run_context_customized = dict(self.run_context)
        for key in ('optimizer', 'net_outputs', 'loss_fn'):
            del run_context_customized[key]

        # Assemble WithLossCell -> TrainOneStep by hand and pin their child
        # cells explicitly, mimicking a user-defined training wrapper.
        loss_net = WithLossCell(self.net, self.loss_fn)
        loss_net._cells = {'_backbone': self.net,
                           '_loss_fn': self.loss_fn}
        train_net = TrainOneStep(loss_net, self.optimizer)
        train_net._cells = {'optimizer': self.optimizer,
                            'network': loss_net,
                            'backbone': self.net}
        run_context_customized['train_network'] = train_net

        train_callback.begin(RunContext(run_context_customized))
        train_callback.end(RunContext(run_context_customized))

        # Block until the data manager has finished loading before querying;
        # start_load_data() presumably returns a thread-like object.
        LINEAGE_DATA_MANAGER.start_load_data().join()
        res = filter_summary_lineage(
            data_manager=LINEAGE_DATA_MANAGER,
            search_condition=self._search_condition)

        # Fail with a clear message instead of a TypeError/IndexError when
        # nothing matched the search condition.
        lineage_objects = res.get('object')
        assert lineage_objects, 'no lineage record matched the search condition'
        model_lineage = lineage_objects[0].get('model_lineage', {})
        assert model_lineage.get('loss_function') \
            == 'SoftmaxCrossEntropyWithLogits'
        assert model_lineage.get('network') == 'ResNet'
        assert model_lineage.get('optimizer') == 'Momentum'