Exemplo n.º 1
0
    def test_initialize_parameters_deep_he(self):
        """
        test ml.common.parameters :: Parameters :: initialize_parameters_deep_he

        Checks the constructor's base_path handling, then verifies that
        initialize_parameters_deep_he produces exactly the expected
        W<l>/b<l> entries for every case in ``tests``. The expected weights
        are compared with exact equality, so this relies on the
        implementation producing deterministic output (presumably via a
        fixed RNG seed inside Parameters — confirm there).
        """
        from ml.common.parameters import Parameters

        obj = Parameters('does not exist')
        self.assertEqual(obj.base_path, self.base_path)

        obj = Parameters(self.test_path, 'file name')
        self.assertEqual(obj.base_path, self.test_path)

        tests = [{
            "ld": [1, 2, 3],
            "parameters": {
                    "W1": [[2.2971712432704137], [-0.8651542170526618]],
                    "b1": [[0.], [0.]],
                    "W2": [[-0.5281717522634557, -1.0729686221561705],
                           [0.8654076293246785, -2.3015386968802827],
                           [1.74481176421648, -0.7612069008951028]],
                    "b2": [[0.], [0.], [0.]]},
        }]
        for test in tests:
            layer_dims = test["ld"]
            expected_parameters = test["parameters"]
            obj.initialize_parameters_deep_he(layer_dims)
            # Convert the array-valued W/b entries to nested lists so that
            # assertDictEqual can compare them. Work on a shallow copy so the
            # object under test is not mutated. Derive the number of layers
            # from layer_dims (one W/b pair per transition between layers),
            # not from a hard-coded count, so cases of any depth work.
            actual_parameters = dict(obj._parameters)
            for layer in range(1, len(layer_dims)):
                for prefix in ("W", "b"):
                    key = prefix + str(layer)
                    actual_parameters[key] = actual_parameters[key].tolist()
            self.assertDictEqual(actual_parameters, expected_parameters)
Exemplo n.º 2
0
    def test_load(self, mock_np):
        """
        test ml.common.parameters :: Parameters :: load

        ``mock_np`` replaces the numpy module seen by Parameters (presumably
        injected by a ``mock.patch`` decorator not visible here). The test
        expects load() to call ``np.load`` on
        ``<test_path>/datasets/<file name>`` and to store the value returned
        by ``.item()`` on the loaded object as its parameters.
        """
        from ml.common.parameters import Parameters

        # Per this test, a non-existent path yields the default base path.
        missing = Parameters('does not exist')
        self.assertEqual(missing.base_path, self.base_path)

        params = Parameters(self.test_path, 'file name')
        self.assertEqual(params.base_path, self.test_path)

        # Stub np.load(...) so that its .item() result is a sentinel value
        # we can look for on the Parameters instance afterwards.
        loaded = MagicMock()
        loaded.item.return_value = 'np load'
        mock_np.load.return_value = loaded

        params.load()

        expected_file = os.path.join(self.test_path, 'datasets', 'file name')
        mock_np.load.assert_called_with(expected_file)
        self.assertEqual(params._parameters, 'np load')