def test_batch_size_has_no_effect_on_cost(self, modelclass,
                                          input_shape_per_sample,
                                          model_kwargs):
  """Checks that the estimated costs do not change with the batch size.

  The model is initialized and lowered to an HLO proto once per batch size;
  the 'compute_cost' and 'memory_cost' entries obtained for the first batch
  size serve as the reference that every later batch size must match.

  Args:
    modelclass: Zero-argument callable returning the module under test.
    input_shape_per_sample: Tuple appended to `(batch_size,)` to build the
      full input shape.
    model_kwargs: Keyword arguments forwarded to both `module.init` and
      `hlo_utils.load_hlo_proto_from_model`.
  """
  module = modelclass()
  reference_compute = None
  reference_memory = None

  for batch_size in (32, 64, 128, 256, 512, 1024):
    input_shape = (batch_size,) + input_shape_per_sample
    rng = random.PRNGKey(0)
    dummy_input = jnp.ones(input_shape, jnp.float32)
    init_state = module.init(rng, dummy_input, **model_kwargs)
    hlo_proto = hlo_utils.load_hlo_proto_from_model(
        module, init_state, [input_shape], **model_kwargs)
    # The state is not needed once the HLO proto has been produced.
    del init_state

    compute_result = compute_cost_utils.estimate_compute_cost(hlo_proto)
    memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)

    # The first batch size fixes the reference; later ones are compared to it.
    compute_cost = compute_result['compute_cost']
    if reference_compute is not None:
      self.assertEqual(compute_cost, reference_compute)
    else:
      reference_compute = compute_cost

    memory_cost = memory_result['memory_cost']
    if reference_memory is not None:
      self.assertEqual(memory_cost, reference_memory)
    else:
      reference_memory = memory_cost
def test_estimate_simple_model_cost(
    self, modelclass, input_shapes, model_kwargs, expected_compute_cost,
    expected_compute_cost_ratio, expected_compute_cost_linear,
    expected_compute_cost_ratio_linear, expected_memory_cost,
    expected_memory_cost_ratio):
  """Compares the estimated costs of a small model with expected values.

  The module is initialized with all-ones float32 inputs, lowered to an HLO
  proto, and the compute and memory estimates are checked for exact equality
  against the expected values.

  Args:
    modelclass: Zero-argument callable returning the module under test.
    input_shapes: Iterable of input shapes, one per positional model input.
    model_kwargs: Keyword arguments forwarded to both `module.init` and
      `hlo_utils.load_hlo_proto_from_model`.
    expected_compute_cost: Expected value of 'compute_cost'.
    expected_compute_cost_ratio: Expected value of
      'compute_cost_ratio_to_bfloat16'.
    expected_compute_cost_linear: Expected value of 'compute_cost_linear'.
    expected_compute_cost_ratio_linear: Expected value of
      'compute_cost_ratio_to_bfloat16_linear'.
    expected_memory_cost: Expected value of 'memory_cost'.
    expected_memory_cost_ratio: Expected value of
      'memory_cost_ratio_to_bfloat16'.
  """
  module = modelclass()
  dummy_inputs = [jnp.ones(shape, dtype=jnp.float32) for shape in input_shapes]
  init_state = module.init(random.PRNGKey(0), *dummy_inputs, **model_kwargs)
  hlo_proto = hlo_utils.load_hlo_proto_from_model(
      module, init_state, input_shapes, **model_kwargs)

  compute_result = compute_cost_utils.estimate_compute_cost(hlo_proto)
  memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)
  logging.info('compute cost result is %s', compute_result)
  logging.info('memory cost result is %s', memory_result)

  # (result dict, key, expected value); each key is looked up only when its
  # own check runs, in the order listed here.
  checks = (
      (compute_result, 'compute_cost', expected_compute_cost),
      (memory_result, 'memory_cost', expected_memory_cost),
      (compute_result, 'compute_cost_ratio_to_bfloat16',
       expected_compute_cost_ratio),
      (memory_result, 'memory_cost_ratio_to_bfloat16',
       expected_memory_cost_ratio),
      (compute_result, 'compute_cost_linear', expected_compute_cost_linear),
      (compute_result, 'compute_cost_ratio_to_bfloat16_linear',
       expected_compute_cost_ratio_linear),
  )
  for result, key, expected in checks:
    self.assertEqual(result[key], expected)
def test_estimate_resnet_cost(self, base_config_filename,
                              expected_compute_cost,
                              expected_compute_cost_ratio,
                              expected_memory_cost,
                              expected_memory_cost_ratio):
  """Compares the estimated ResNet costs with expected values.

  Hyperparameters are loaded from the given config, an HLO proto is built for
  a (1024, 224, 224, 3) input, and the resulting estimates are checked:
  absolute costs with exact equality, bfloat16 ratios with `assertAlmostEqual`.

  Args:
    base_config_filename: Object exposing `get_config()`; despite the name it
      is not used as a plain string here.
    expected_compute_cost: Expected value of 'compute_cost'.
    expected_compute_cost_ratio: Expected value of
      'compute_cost_ratio_to_bfloat16'.
    expected_memory_cost: Expected value of 'memory_cost'.
    expected_memory_cost_ratio: Expected value of
      'memory_cost_ratio_to_bfloat16'.
  """
  num_images, side, channels = 1024, 224, 3
  input_shape = (num_images, side, side, channels)

  logging.info('Testing for %s...', base_config_filename)
  config_dict = base_config_filename.get_config()
  hparams = hparams_utils.load_hparams_from_config_dict(
      hparams_config.TrainingHParams, models.ResNet.HParams, config_dict)
  hlo_proto = self._create_hlo_from_resnet_hparams(hparams, input_shape)

  compute_result = compute_cost_utils.estimate_compute_cost(hlo_proto)
  memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)

  # Compute estimates first, then memory; ratios are floats, hence the
  # approximate comparison.
  self.assertEqual(compute_result['compute_cost'], expected_compute_cost)
  compute_ratio = compute_result['compute_cost_ratio_to_bfloat16']
  self.assertAlmostEqual(compute_ratio, expected_compute_cost_ratio)

  self.assertEqual(memory_result['memory_cost'], expected_memory_cost)
  memory_ratio = memory_result['memory_cost_ratio_to_bfloat16']
  self.assertAlmostEqual(memory_ratio, expected_memory_cost_ratio)
def estimate_compute_and_memory_cost(image_size, model_dir, hparams):
  """Estimate compute and memory cost of model.

  Builds the model for a single (1, image_size, image_size, 3) float32 input,
  lowers it to an HLO proto, runs the compute and memory cost estimators and
  writes the merged result as JSON to
  `<model_dir>/<COMPUTE_MEMORY_COST_FILENAME>`.

  `FLAGS.metadata_enabled` is switched on while the model is built and
  analysed and is always switched off again afterwards, including when one of
  those steps raises, so a failure cannot leave the global flag enabled.

  Args:
    image_size: Height and width of the square input image.
    model_dir: Directory in which the cost file is written.
    hparams: Hyperparameter object; only `hparams.model_hparams` is read.
  """
  input_shape = (1, image_size, image_size, 3)

  # NOTE(review): presumably the estimators rely on metadata that is only
  # emitted while this flag is set -- confirm against compute_cost_utils.
  FLAGS.metadata_enabled = True
  try:
    model, init_state = imagenet_train_utils.create_model(
        jax.random.PRNGKey(0),
        input_shape[0],
        input_shape[1],
        jnp.float32,
        hparams.model_hparams,
        train=False)
    hlo_proto = hlo_utils.load_hlo_proto_from_model(model, init_state,
                                                    [input_shape])
    # Model and state are no longer needed once the HLO proto exists.
    del model, init_state
    cost_dict = compute_cost_utils.estimate_compute_cost(hlo_proto)
    memory_cost_dict = compute_cost_utils.estimate_memory_cost(hlo_proto)
  finally:
    # Reset to False (not to the previous value) to keep the behavior callers
    # already observe on the success path.
    FLAGS.metadata_enabled = False

  cost_dict.update(memory_cost_dict)
  path = os.path.join(model_dir, COMPUTE_MEMORY_COST_FILENAME)
  with open(path, 'w') as cost_file:
    json.dump(cost_dict, cost_file, indent=2)
  logging.info('Estimated compute and memory costs and wrote to file')