def test_base_attrs(self):
    """Extract a node created with no extra attributes and check its attrs.

    Every key in ``self.base_attrs`` must be present on the extracted node
    with a value equal to the one recorded in ``self.base_attrs``.
    """
    node = self._create_node()
    RNNFrontExtractor.extract(node)

    for key, expected in self.base_attrs.items():
        # dtype=object makes numpy use its object loop, i.e. Python ``==``
        # element by element, so list/str/float attribute values compare
        # without being coerced to a common numeric dtype.
        equal = np.all(np.equal(node[key], expected, dtype=object))
        # Report which attribute failed instead of a bare "False is not true".
        self.assertTrue(equal, 'Values for attr {} are not equal'.format(key))
def test_additional_attributes(self):
    """Check that explicitly supplied RNN attributes survive extraction.

    The node is built with ``activation_alpha``, ``activations`` (as byte
    strings) and ``clip``. After extraction, each of those and every entry of
    ``self.base_attrs`` must match. The activations are expected back as
    decoded ``str`` values rather than ``bytes``.
    """
    extra = {
        'activation_alpha': [1.0, 0.0, 2.0],
        'activations': [b'relu', b'tanh', b'sigmoid'],
        'clip': 10.0,
    }
    node = self._create_node(**extra)
    RNNFrontExtractor.extract(node)

    # The supplied attrs override the base ones; the activation names are
    # expected as text, so decode the byte strings that were fed in.
    expected_attrs = {
        **self.base_attrs,
        **extra,
        'activations': [name.decode() for name in extra['activations']],
    }

    for attr_name, want in expected_attrs.items():
        matches = np.all(np.equal(node[attr_name], want, dtype=object))
        self.assertTrue(matches, 'Values for attr {} are not equal'.format(attr_name))