Example #1
0
 def test_collect_histogram_from_regular(self, mock_add_value,
                                         histogram_regular, expected_names,
                                         expected_values):
     """Check histogram collection filtered by the `histogram_regular` option.

     Six parameters are attached to an optimizer on fresh callback params.
     The collector's 'histogram_regular' setting is overridden, and
     `_collect_histogram` is run inside the collector's context. The data
     returned by `get_value()` afterwards must:
     - start with a histogram-plugin entry;
     - match the expected names (field 1) and values (field 2).

     Args:
         mock_add_value: Mock object, presumably injected by a patch
             decorator outside this view; its calls are forwarded to
             `add_value`.
         histogram_regular: Value written under the 'histogram_regular' key
             of the collector's `_collect_specified_data`.
         expected_names: Expected second field of every recorded entry.
         expected_values: Expected third field of every recorded entry.
     """
     # Forward calls on the mock to `add_value`; results are read back below
     # through `get_value()`.
     mock_add_value.side_effect = add_value
     callback_params = _InternalCallbackParam()

     # (tensor value, parameter name) pairs for the optimizer's parameters.
     parameter_specs = [
         (1, 'conv1.weight1'),
         (2, 'conv2.weight2'),
         (3, 'conv1.bias1'),
         (4, 'conv3.bias'),
         (5, 'conv5.bias'),
         (6, 'conv6.bias'),
     ]
     callback_params.optimizer = Optimizer(
         learning_rate=0.1,
         parameters=[Parameter(Tensor(value), name)
                     for value, name in parameter_specs],
     )

     summary_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
     with SummaryCollector(summary_dir) as collector:
         specified_data = collector._collect_specified_data
         specified_data['histogram_regular'] = histogram_regular
         collector._collect_histogram(callback_params)

     records = get_value()
     assert PluginEnum.HISTOGRAM.value == records[0][0]

     recorded_names = [entry[1] for entry in records]
     assert expected_names == recorded_names

     recorded_values = [entry[2] for entry in records]
     assert expected_values == recorded_values
Example #2
0
    def test_get_optimizer_from_cb_params_success(self):
        """Check `_get_optimizer` returns the optimizer held by the callback params.

        The lookup is issued twice on the same collector, and both results
        must compare equal to the optimizer stored on the callback params.
        """
        params = _InternalCallbackParam()
        params.optimizer = Optimizer(
            learning_rate=0.1,
            parameters=[Parameter(Tensor(1), 'weight')],
        )
        log_dir = tempfile.mkdtemp(dir=self.base_summary_dir)
        collector = SummaryCollector(log_dir)

        # Query twice: a repeated lookup must still yield the same optimizer.
        for _ in range(2):
            assert collector._get_optimizer(params) == params.optimizer
Example #3
0
 def __init__(self):
     """Build the network's add operator and a one-parameter optimizer.

     Sets:
         add: an instance of the `TensorAdd` primitive.
         optimizer: an `Optimizer` with learning rate 1 over a single
             parameter named 'weight'.
     """
     super(CustomNet, self).__init__()
     # Instantiate the primitive. Storing the bare class would make a call
     # like `self.add(x, y)` invoke the constructor instead of adding tensors.
     self.add = TensorAdd()
     self.optimizer = Optimizer(learning_rate=1,
                                parameters=[Parameter(Tensor(1), 'weight')])