Example #1
0
 def test_begin(self, *args):
     """Check that calling begin() ends up calling the object at args[4].

     ``args`` is addressed purely by position. Its entries are configured
     through ``return_value`` and checked with ``assert_called``, so they are
     expected to be mock objects (presumably injected by patch decorators
     that are not visible here -- confirm against the decorator stack).
     """
     # Program the positional stubs before building the object under test.
     stubbed_returns = ((1, None), (2, Optimizer(Tensor(0.1))), (3, None))
     for position, value in stubbed_returns:
         args[position].return_value = value
     args[5].serialize.return_value = {}

     context_data = dict(optimizer=Optimizer(Tensor(0.1)), epoch_num=10)
     record = self.my_summary_record(self.summary_log_path)
     lineage = self.my_train_module(record)
     wrapped_context = self.my_run_context(context_data)
     lineage.begin(wrapped_context)

     args[4].assert_called()
Example #2
0
 def test_get_loss_fn_by_network(self, mock_vars):
     """Check the object returned by AnalyzeObject.get_loss_fn_by_network.

     ``mock_vars.side_effect`` is set to a list of three mappings, each with
     a ``_cells`` entry. The call under test must return the object stored
     under ``'loss'`` in the third mapping.
     """
     other_loss = SoftmaxCrossEntropyWithLogits(0.2)
     optimizer = Optimizer(Tensor(0.1))
     expected_loss = SoftmaxCrossEntropyWithLogits(0.1)
     mock_vars.side_effect = [
         {'_cells': {'key': other_loss}},
         {'_cells': {'opt': optimizer}},
         {'_cells': {'loss': expected_loss}},
     ]

     network = MagicMock()
     actual = AnalyzeObject.get_loss_fn_by_network(network)

     self.assertEqual(actual, expected_loss)
Example #3
0
 def test_get_optimizer_by_network(self, mock_vars):
     """Check the object returned by AnalyzeObject.get_optimizer_by_network.

     ``mock_vars`` is configured to return a mapping whose ``_cells`` entry
     holds a single optimizer. The call under test must return that same
     optimizer object.
     """
     mock_optimizer = Optimizer(Tensor(0.1))
     # Only the value returned by ``mock_vars`` feeds the call below; the
     # network argument itself is an opaque MagicMock.
     mock_vars.return_value = {'_cells': {'key': mock_optimizer}}
     res = AnalyzeObject.get_optimizer_by_network(MagicMock())
     self.assertEqual(res, mock_optimizer)
Example #4
0
 def test_begin_error(self, *args):
     """Exercise begin() while the object at args[4] is set to raise Exception.

     Two objects are built through ``self.my_train_module``:

     * with ``True`` as the second positional argument, ``begin`` must raise
       ``LineageLogError`` matching 'Dataset graph log error';
     * without that argument, ``begin`` must return without raising.

     Finally the object at args[4] is checked to have been called.
     ``args`` is addressed purely by position; its entries are expected to be
     mock objects supplied from outside this method.
     """
     # Program the positional stubs; slot 4 is the one that blows up.
     stubbed_returns = ((1, None), (2, Optimizer(Tensor(0.1))), (3, None))
     for position, value in stubbed_returns:
         args[position].return_value = value
     args[4].side_effect = Exception
     args[5].serialize.return_value = {}
     context_data = dict(optimizer=Optimizer(Tensor(0.1)), epoch_num=10)

     # First scenario: built with the extra ``True`` argument, begin() raises.
     record = self.my_summary_record(self.summary_log_path)
     lineage = self.my_train_module(record, True)
     with self.assertRaisesRegex(LineageLogError, 'Dataset graph log error'):
         lineage.begin(self.my_run_context(context_data))

     # Second scenario: built without it, begin() must complete normally.
     record = self.my_summary_record(self.summary_log_path)
     lineage = self.my_train_module(record)
     lineage.begin(self.my_run_context(context_data))

     args[4].assert_called()
Example #5
0
 def test_analyze_optimizer(self):
     """Check that analyze_optimizer returns 0.12 for Optimizer(Tensor(0.12))."""
     subject = Optimizer(Tensor(0.12))
     assert self.analyzer.analyze_optimizer(subject) == 0.12