Example 1
    def test_batch_size_has_no_effect_on_cost(self, modelclass,
                                              input_shape_per_sample,
                                              model_kwargs):
        expected_compute_cost = None
        expected_memory_cost = None
        batch_size_list = [32, 64, 128, 256, 512, 1024]

        module = modelclass()

        # Sweep over the batch size list
        for batch_size in batch_size_list:
            input_shape = (batch_size, ) + input_shape_per_sample
            init_state = module.init(random.PRNGKey(0),
                                     jnp.ones(input_shape, jnp.float32),
                                     **model_kwargs)
            hlo_proto = hlo_utils.load_hlo_proto_from_model(
                module, init_state, [input_shape], **model_kwargs)
            del init_state
            compute_result = compute_cost_utils.estimate_compute_cost(
                hlo_proto)
            memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)
            # Save the first cost and compare it with the rest
            if expected_compute_cost is None:
                expected_compute_cost = compute_result['compute_cost']
            else:
                self.assertEqual(compute_result['compute_cost'],
                                 expected_compute_cost)
            if expected_memory_cost is None:
                expected_memory_cost = memory_result['memory_cost']
            else:
                self.assertEqual(memory_result['memory_cost'],
                                 expected_memory_cost)
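The test above receives its model classes and input shapes from the test harness. As a standalone illustration of the same pipeline (initialize a Flax module, lower it to an HLO proto, estimate its cost), here is a minimal sketch; ToyMlp, its layer sizes, and the feature dimension are assumptions made purely for illustration, and hlo_utils / compute_cost_utils are assumed to be importable exactly as in the test file above.

import flax.linen as nn
import jax.numpy as jnp
from jax import random


class ToyMlp(nn.Module):
    """Hypothetical two-layer MLP used only to illustrate the pattern."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        return nn.Dense(features=10)(x)


def compute_cost_for_batch(batch_size, features=784):
    """Traces ToyMlp to an HLO proto and returns its estimated compute cost."""
    module = ToyMlp()
    input_shape = (batch_size, features)
    init_state = module.init(random.PRNGKey(0),
                             jnp.ones(input_shape, jnp.float32))
    # Same helpers the tests use; their import path is not shown in the snippet.
    hlo_proto = hlo_utils.load_hlo_proto_from_model(module, init_state,
                                                    [input_shape])
    return compute_cost_utils.estimate_compute_cost(hlo_proto)['compute_cost']


costs = {compute_cost_for_batch(b) for b in (32, 64, 128)}
# Per the test above, this set is expected to collapse to a single value.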
Example 2
    def test_estimate_simple_model_cost(
            self, modelclass, input_shapes, model_kwargs,
            expected_compute_cost, expected_compute_cost_ratio,
            expected_compute_cost_linear, expected_compute_cost_ratio_linear,
            expected_memory_cost, expected_memory_cost_ratio):
        module = modelclass()
        input_shapes_with_type = [(sh, jnp.float32) for sh in input_shapes]
        dummy_inputs = [
            jnp.ones(input_shape, dtype=dtype)
            for (input_shape, dtype) in input_shapes_with_type
        ]
        init_state = module.init(random.PRNGKey(0), *dummy_inputs,
                                 **model_kwargs)

        # Lower the initialized model to an HLO proto for cost analysis.
        hlo_proto = hlo_utils.load_hlo_proto_from_model(
            module, init_state, input_shapes, **model_kwargs)
        compute_result = compute_cost_utils.estimate_compute_cost(hlo_proto)
        memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)
        logging.info('compute cost result is %s', compute_result)
        logging.info('memory cost result is %s', memory_result)
        self.assertEqual(compute_result['compute_cost'], expected_compute_cost)
        self.assertEqual(memory_result['memory_cost'], expected_memory_cost)
        self.assertEqual(compute_result['compute_cost_ratio_to_bfloat16'],
                         expected_compute_cost_ratio)
        self.assertEqual(memory_result['memory_cost_ratio_to_bfloat16'],
                         expected_memory_cost_ratio)
        self.assertEqual(compute_result['compute_cost_linear'],
                         expected_compute_cost_linear)
        self.assertEqual(
            compute_result['compute_cost_ratio_to_bfloat16_linear'],
            expected_compute_cost_ratio_linear)

    def test_estimate_resnet_cost(self, base_config_filename,
                                  expected_compute_cost,
                                  expected_compute_cost_ratio,
                                  expected_memory_cost,
                                  expected_memory_cost_ratio):
        batch_size = 1024
        image_size = 224
        input_channels = 3
        input_shape = (batch_size, image_size, image_size, input_channels)

        logging.info('Testing for %s...', base_config_filename)
        hparams = hparams_utils.load_hparams_from_config_dict(
            hparams_config.TrainingHParams, models.ResNet.HParams,
            base_config_filename.get_config())

        hlo_proto = self._create_hlo_from_resnet_hparams(hparams, input_shape)
        compute_result = compute_cost_utils.estimate_compute_cost(hlo_proto)
        memory_result = compute_cost_utils.estimate_memory_cost(hlo_proto)
        self.assertEqual(compute_result['compute_cost'], expected_compute_cost)
        self.assertAlmostEqual(
            compute_result['compute_cost_ratio_to_bfloat16'],
            expected_compute_cost_ratio)
        self.assertEqual(memory_result['memory_cost'], expected_memory_cost)
        self.assertAlmostEqual(memory_result['memory_cost_ratio_to_bfloat16'],
                               expected_memory_cost_ratio)
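Only the test bodies are shown above; the parameterization that supplies base_config_filename and the expected cost numbers comes from decorators outside the snippet. A hedged sketch of how such a case might be wired up with absl's parameterized helpers follows; the config stand-in, the test-case name, and all expected values are placeholders, not values from the real suite.

from absl.testing import parameterized


class _FakeResNetConfig:
    """Placeholder for a config module exposing get_config()."""

    @staticmethod
    def get_config():
        return {}


class ResNetCostTest(parameterized.TestCase):

    @parameterized.named_parameters(
        dict(
            testcase_name='resnet_placeholder',
            base_config_filename=_FakeResNetConfig,
            expected_compute_cost=0,
            expected_compute_cost_ratio=1.0,
            expected_memory_cost=0,
            expected_memory_cost_ratio=1.0,
        ),
    )
    def test_estimate_resnet_cost(self, base_config_filename,
                                  expected_compute_cost,
                                  expected_compute_cost_ratio,
                                  expected_memory_cost,
                                  expected_memory_cost_ratio):
        ...  # body as in the snippet above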
Example 4
def estimate_compute_and_memory_cost(image_size, model_dir, hparams):
    """Estimate compute and memory cost of model."""
    # Turn on metadata while the model is traced; the flag is reset below.
    FLAGS.metadata_enabled = True
    input_shape = (1, image_size, image_size, 3)
    model, init_state = imagenet_train_utils.create_model(
        jax.random.PRNGKey(0),
        input_shape[0],
        input_shape[1],
        jnp.float32,
        hparams.model_hparams,
        train=False)
    hlo_proto = hlo_utils.load_hlo_proto_from_model(model, init_state,
                                                    [input_shape])
    del model, init_state
    cost_dict = compute_cost_utils.estimate_compute_cost(hlo_proto)
    memory_cost_dict = compute_cost_utils.estimate_memory_cost(hlo_proto)
    cost_dict.update(memory_cost_dict)
    FLAGS.metadata_enabled = False

    path = os.path.join(model_dir, COMPUTE_MEMORY_COST_FILENAME)
    with open(path, 'w') as file:
        json.dump(cost_dict, file, indent=2)
    logging.info('Estimated compute and memory costs and wrote to file')
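A hypothetical call site for the helper above; the hparams object is assumed to be loaded elsewhere (for instance via hparams_utils, as in the ResNet test earlier in this listing), the model directory is a placeholder, and COMPUTE_MEMORY_COST_FILENAME is the module-level constant the function references.

estimate_compute_and_memory_cost(image_size=224, model_dir='/tmp/model',
                                 hparams=hparams)  # hparams loaded elsewhere

# The merged cost dictionary can then be read back from the JSON file.
with open(os.path.join('/tmp/model', COMPUTE_MEMORY_COST_FILENAME)) as f:
    cost_dict = json.load(f)
print(cost_dict['compute_cost'], cost_dict['memory_cost'])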