def test_begin(self, *args):
    """Check that TrainLineage.begin runs and invokes the mock at ``args[4]``.

    ``args`` holds positionally injected mock objects (presumably supplied
    by patch decorators on this test — confirm against the decorator list);
    indices 1, 2, 3 and 5 are configured here, and index 4 is verified.
    """
    # Configure the injected mocks before the callback is exercised.
    args[1].return_value = None
    args[2].return_value = Optimizer(Tensor(0.1))
    args[3].return_value = None
    args[5].serialize.return_value = {}

    context = dict(optimizer=Optimizer(Tensor(0.1)), epoch_num=10)
    lineage_callback = self.my_train_module(
        self.my_summary_record(self.summary_log_path))
    lineage_callback.begin(self.my_run_context(context))

    # begin() must have reached the dependency mocked at args[4].
    args[4].assert_called()
def test_get_loss_fn_by_network(self, mock_vars):
    """Check that get_loss_fn_by_network returns the expected loss cell.

    ``mock_vars`` is made to yield three ``_cells`` mappings in sequence;
    the test expects the object stored under the third mapping's ``'loss'``
    key to be returned.
    """
    first_cells = {'_cells': {'key': SoftmaxCrossEntropyWithLogits(0.2)}}
    second_cells = {'_cells': {'opt': Optimizer(Tensor(0.1))}}
    expected_loss = SoftmaxCrossEntropyWithLogits(0.1)
    third_cells = {'_cells': {'loss': expected_loss}}
    # Each call to the patched object returns the next mapping in order.
    mock_vars.side_effect = [first_cells, second_cells, third_cells]

    found = AnalyzeObject.get_loss_fn_by_network(MagicMock())
    self.assertEqual(found, expected_loss)
def test_get_optimizer_by_network(self, mock_vars):
    """Check that get_optimizer_by_network returns the optimizer cell.

    ``mock_vars`` is made to return a single ``_cells`` mapping that holds
    one optimizer; that same object is expected back.
    """
    mock_optimizer = Optimizer(Tensor(0.1))
    # The former ``mock_cells`` MagicMock was configured but never used by
    # anything in this test, so that dead setup has been removed.
    mock_vars.return_value = {'_cells': {'key': mock_optimizer}}
    res = AnalyzeObject.get_optimizer_by_network(MagicMock())
    self.assertEqual(res, mock_optimizer)
def test_begin_error(self, *args):
    """Check TrainLineage.begin when the mock at ``args[4]`` raises.

    When the train module is built with a second argument of ``True``
    (presumably a raise-on-error flag — confirm), the failure must surface
    as ``LineageLogError`` matching 'Dataset graph log error'. When it is
    built without that argument, the same failure must not propagate.
    In both cases ``args[4]`` must have been called.
    """
    # Positionally injected mocks (presumably from patch decorators).
    args[1].return_value = None
    args[2].return_value = Optimizer(Tensor(0.1))
    args[3].return_value = None
    args[4].side_effect = Exception
    args[5].serialize.return_value = {}

    context = dict(optimizer=Optimizer(Tensor(0.1)), epoch_num=10)

    # Strict construction: the error is expected to be re-raised.
    lineage = self.my_train_module(
        self.my_summary_record(self.summary_log_path), True)
    with self.assertRaisesRegex(LineageLogError, 'Dataset graph log error'):
        lineage.begin(self.my_run_context(context))

    # Default construction: begin() must complete without raising.
    lineage = self.my_train_module(
        self.my_summary_record(self.summary_log_path))
    lineage.begin(self.my_run_context(context))

    args[4].assert_called()
def test_analyze_optimizer(self):
    """Check that analyze_optimizer returns the scalar the optimizer wraps.

    The optimizer is built from ``Tensor(0.12)``, and the analyzer is
    expected to report 0.12.
    """
    optimizer = Optimizer(Tensor(0.12))
    res = self.analyzer.analyze_optimizer(optimizer)
    # Use the TestCase assertion rather than a bare ``assert``, which is
    # stripped under ``python -O`` and would make this test vacuous.
    # assertAlmostEqual also tolerates float rounding; it first checks
    # ``==``, so every value the old check accepted is still accepted.
    self.assertAlmostEqual(res, 0.12)