Exemplo n.º 1
0
 def test_training_end(self, *args):
     """Check that TrainLineage.end persists the epoch count in the lineage.

     The first injected mock (``args[0]``) is made to return 64 before the
     callback is built. ``end`` is invoked twice: once with the fixture run
     context and once after its ``epoch_num`` is set to 14; after each call
     the ``hyper_parameters.epoch`` value read back from SUMMARY_DIR is
     checked.
     """
     first_mock = args[0]
     first_mock.return_value = 64
     callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)
     callback.initial_learning_rate = 0.12

     def recorded_epoch():
         lineage = get_summary_lineage(summary_dir=SUMMARY_DIR)
         return lineage.get('hyper_parameters', {}).get('epoch')

     callback.end(RunContext(self.run_context))
     assert recorded_epoch() == 10

     # NOTE: `context` is the fixture dict itself (no copy), so this also
     # updates self.run_context.
     context = self.run_context
     context['epoch_num'] = 14
     callback.end(RunContext(context))
     assert recorded_epoch() == 14
Exemplo n.º 2
0
 def test_train_with_customized_network(self, *args):
     """Check lineage extraction from a hand-assembled training network.

     The optimizer, net outputs and loss function are removed from the run
     context and a ``train_network`` (TrainOneStep around a WithLossCell, with
     explicit ``_cells`` maps) is supplied instead; presumably this forces
     TrainLineage to discover those components from the network structure.
     The recorded loss function, network and optimizer names are then checked.
     """
     args[0].return_value = 64
     callback = TrainLineage(SUMMARY_DIR, True, self.user_defined_info)

     # NOTE: this is the fixture dict itself, not a copy; the deletions and
     # the assignment below mutate self.run_context.
     context = self.run_context
     for key in ('optimizer', 'net_outputs', 'loss_fn'):
         del context[key]

     loss_cell = WithLossCell(self.net, self.loss_fn)
     loss_cell._cells = {'_backbone': self.net, '_loss_fn': self.loss_fn}
     train_net = TrainOneStep(loss_cell, self.optimizer)
     train_net._cells = {
         'optimizer': self.optimizer,
         'network': loss_cell,
         'backbone': self.net,
     }
     context['train_network'] = train_net

     callback.begin(RunContext(context))
     callback.end(RunContext(context))

     lineage = get_summary_lineage(summary_dir=SUMMARY_DIR)
     hyper_params = lineage.get('hyper_parameters', {})
     assert (hyper_params.get('loss_function')
             == 'SoftmaxCrossEntropyWithLogits')
     assert lineage.get('algorithm', {}).get('network') == 'ResNet'
     assert hyper_params.get('optimizer') == 'Momentum'
Exemplo n.º 3
0
 def test_get_summary_lineage(self):
     """Check the full lineage of ``./run1`` read through the data manager."""
     actual = get_summary_lineage(
         data_manager=self._data_manger, summary_dir="./run1"
     )
     assert_equal_lineages(LINEAGE_INFO_RUN1, actual, self.assertDictEqual)
Exemplo n.º 4
0
    def test_get_summary_lineage(self):
        """Exercise get_summary_lineage with full and filtered key selections.

        Covers: no key filter, a ``hyper_parameters`` subset, a ``metric`` +
        ``algorithm`` subset, a directory whose lineage summary file is empty
        (expects ``{}``), and an explicit empty key list (expects only
        ``summary_dir``).
        """
        full = get_summary_lineage(None, SUMMARY_DIR)
        only_hyper = get_summary_lineage(
            None, SUMMARY_DIR, ['hyper_parameters']
        )
        metric_and_algo = get_summary_lineage(
            None, SUMMARY_DIR, ['metric', 'algorithm']
        )

        run1_dir = os.path.join(BASE_SUMMARY_DIR, 'run1')
        expected_hyper = {
            'summary_dir': run1_dir,
            'hyper_parameters': {
                'optimizer': 'Momentum',
                'learning_rate': 0.12,
                'loss_function': 'SoftmaxCrossEntropyWithLogits',
                'epoch': 14,
                'parallel_mode': 'stand_alone',
                'device_num': 2,
                'batch_size': 32,
            },
        }
        expected_metric_and_algo = {
            'summary_dir': run1_dir,
            'metric': {'accuracy': 0.78},
            'algorithm': {'network': 'ResNet'},
        }

        comparisons = (
            (LINEAGE_INFO_RUN1, full),
            (expected_hyper, only_hyper),
            (expected_metric_and_algo, metric_and_algo),
        )
        for expected, actual in comparisons:
            assert_equal_lineages(expected, actual, self.assertDictEqual)

        # A directory whose lineage summary file is empty yields no lineage.
        assert {} == get_summary_lineage(None, self.dir_with_empty_lineage)

        # An empty key list yields only the summary directory.
        expected_dir_only = {'summary_dir': SUMMARY_DIR}
        assert expected_dir_only == get_summary_lineage(None, SUMMARY_DIR, [])
Exemplo n.º 5
0
 def test_get_summary_lineage_failed3(self, mock_summary, mock_valid,
                                      mock_parser, mock_file_handler):
     """Test get_summary_lineage returns {} when the parser mock yields None.

     The injected mocks are configured with a summary-file list, a validated
     directory path, a parser returning ``None`` and a file handler of
     non-zero size; the call is then expected to produce an empty dict.
     """
     mock_summary.return_value = ['/path/to/summary/file']
     mock_valid.return_value = '/path/to/summary_dir'
     mock_parser.return_value = None
     # Configure the injected patch rather than rebinding the parameter name:
     # rebinding (`mock_file_handler = MagicMock()`) leaves the patched object
     # untouched, so the non-empty size would never reach the code under test.
     # NOTE(review): assumes the patched target is called/instantiated to
     # obtain the handler — confirm against the @mock.patch decorator.
     file_handler = MagicMock()
     file_handler.size = 1
     mock_file_handler.return_value = file_handler
     result = get_summary_lineage(None, '/path/to/summary_dir')
     assert {} == result
Exemplo n.º 6
0
    def test_get_summary_lineage_success(self, isdir_mock, parser_mock,
                                         qurier_mock):
        """Check that the single lineage returned by the querier is unwrapped.

        The querier mock yields a one-element list; get_summary_lineage for a
        single directory is expected to return that element as a plain dict.
        """
        isdir_mock.return_value = True
        parser_mock.return_value = MagicMock()

        querier = MagicMock()
        querier.get_summary_lineage.return_value = [
            {'algorithm': {'network': 'ResNet'}}
        ]
        qurier_mock.return_value = querier

        result = get_summary_lineage(
            None, '/path/to/summary_dir', keys=['algorithm']
        )
        self.assertEqual(result, {'algorithm': {'network': 'ResNet'}})
Exemplo n.º 7
0
def get_dataset_graph():
    """
    Get dataset graph.

    The ``summary_dir`` field of the returned lineage is rewritten relative to
    ``settings.SUMMARY_BASE_DIR`` (``./`` when it is the base dir itself).

    Returns:
        The ``jsonify``-ed dataset graph lineage for the requested train_id
        (an empty/falsy lineage is returned unchanged).

    Raises:
        MindInsightException: If method fails to be called (re-raised with
            HTTP code 400).
        ParamValueError: If summary_dir is invalid.

    Examples:
        >>> GET http://xxxx/v1/mindinsight/datasets/dataset_graph?train_id=xxx
    """

    summary_base_dir = str(settings.SUMMARY_BASE_DIR)
    summary_dir = get_train_id(request)
    try:
        dataset_graph = get_summary_lineage(
            DATA_MANAGER,
            summary_dir=summary_dir,
            keys=['dataset_graph']
        )
    except MindInsightException as exception:
        # Surface lineage failures as client errors, keeping the cause chained.
        raise MindInsightException(
            exception.error, exception.message, http_code=400
        ) from exception

    if dataset_graph:
        # Strip the base-dir prefix from the summary dir the lineage reports.
        # The request's train_id (`summary_dir`) is passed through as-is and
        # may be relative (e.g. "./run1"), so it must not be sliced by the
        # base-dir length. Assumes the reported path lies under the base dir.
        summary_dir_result = dataset_graph.get('summary_dir') or ''
        if summary_dir_result == summary_base_dir:
            relative_dir = './'
        else:
            relative_dir = os.path.join(
                os.curdir, summary_dir_result[len(summary_base_dir) + 1:]
            )
        dataset_graph['summary_dir'] = relative_dir

    return jsonify(dataset_graph)