def test_additional_attrs(self):
    """Check extractor output for a two-layer bidirectional LSTM node.

    The node is built with ``state_size=128``, ``mode='lstm'``,
    ``bidirectional=True`` and ``num_layers=2``. After extraction, every
    entry of ``self.base_attrs`` plus the multilayer/bidirectional
    attributes below must be present on the node with equal values. The
    test expects ``state_size`` to reappear as ``hidden_size``.
    """
    node = self._create_node(
        state_size=128,
        mode='lstm',
        bidirectional=True,
        num_layers=2,
    )
    RNNFrontExtractor.extract(node)

    # Base attributes first; the extra entries override any shared keys,
    # matching ``{**self.base_attrs, **extra}`` semantics.
    expected = dict(self.base_attrs)
    expected.update(
        multilayers=True,
        hidden_size=128,
        has_num_directions=True,
        direction='bidirectional',
        num_layers=2,
    )

    for attr_name, attr_value in expected.items():
        # dtype=object lets np.equal compare strings/None/arrays uniformly.
        matches = np.all(np.equal(node[attr_name], attr_value, dtype=object))
        self.assertTrue(matches, 'Values for attr {} are not equal'.format(attr_name))
def test_unsupported_mode(self):
    """An unrecognised ``mode`` value must make the extractor raise ``Error``."""
    # Build the node outside the assertRaises context so that only the
    # extraction step is expected to raise.
    bogus_node = self._create_node(state_size=128, mode='abracadabra')
    with self.assertRaises(Error):
        RNNFrontExtractor.extract(bogus_node)