def test_use_external_data_format(self):
    """
    External data format is required only if the serialized size of the parameters is bigger than 2Gb
    """
    size_limit = EXTERNAL_DATA_FORMAT_SIZE_LIMIT
    float_size = ParameterFormat.Float.size
    needs_external_data = OnnxConfig.use_external_data_format

    # Parameter counts that must NOT trigger the external data format:
    # no parameters, a single parameter, and (limit - 1) divided by the float size.
    counts_below_limit = (
        0,
        1,
        (size_limit - 1) // float_size,
    )
    for parameter_count in counts_below_limit:
        self.assertFalse(needs_external_data(parameter_count))

    # Parameter counts that MUST trigger the external data format:
    # the limit value itself used as a count, and (limit + 1) divided by the float size.
    counts_at_or_above_limit = (
        size_limit,
        (size_limit + 1) // float_size,
    )
    for parameter_count in counts_at_or_above_limit:
        self.assertTrue(needs_external_data(parameter_count))
def test_flatten_output_collection_property(self):
    """
    Check that a nested collection (such as the one returned for past_keys, i.e. Tuple[Tuple]) is flattened
    correctly.

    The ONNX exporter exports nested collections as
    ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
    """
    # Three single-element inner collections, one per index of the outer collection.
    nested_values = [[0], [1], [2]]

    flattened = OnnxConfig.flatten_output_collection_property("past_key", nested_values)

    # Each inner element is expected under "<name>.<outer index>".
    expected = {
        "past_key.0": 0,
        "past_key.1": 1,
        "past_key.2": 2,
    }
    self.assertEqual(flattened, expected)