Exemple #1
0
            def _reconcile_scaling_config_with_trial_resources(
                self, scaling_config: ScalingConfig
            ) -> ScalingConfig:
                """Make ``scaling_config`` agree with this trial's resources.

                Stop-gap for the ResourceChangingScheduler, which currently
                changes ``self.trial_resources`` rather than handing back a
                ScalingConfig; it should go away once RCS returns a
                ScalingConfig directly.

                Args:
                    scaling_config: The scaling config currently in use (may be
                        falsy, in which case it is not validated).

                Returns:
                    ``scaling_config`` unchanged if the trial resources are not a
                    PlacementGroupFactory; otherwise the (validated) config, or a
                    validated config rebuilt from the trial resources when the
                    two disagree.
                """
                resources = self.trial_resources

                # Per the original author, anything other than a
                # PlacementGroupFactory here means default resources, so there
                # is nothing to reconcile against.
                if not isinstance(resources, PlacementGroupFactory):
                    return scaling_config

                # Only run validation when a config was actually supplied.
                current = (
                    trainer_cls._validate_scaling_config(scaling_config)
                    if scaling_config
                    else scaling_config
                )
                derived = ScalingConfig.from_placement_group_factory(resources)

                # Without ResourceChangingScheduler the two should match; a
                # mismatch means the trial resources changed, and they win.
                if derived != current:
                    return trainer_cls._validate_scaling_config(derived)
                return current
Exemple #2
0
def test_scaling_config_pgf_equivalance(
    trainer_resources,
    resources_per_worker_and_use_gpu,
    num_workers,
    placement_strategy,
):
    """A ScalingConfig must survive a round trip through a PlacementGroupFactory.

    Builds a ScalingConfig from the parameters, converts it to a
    PlacementGroupFactory and back, then checks that the rebuilt config equals
    the original and that converting the rebuilt config again yields an equal
    factory.
    """
    per_worker, gpu_enabled = resources_per_worker_and_use_gpu

    original = ScalingConfig(
        trainer_resources=trainer_resources,
        num_workers=num_workers,
        resources_per_worker=per_worker,
        use_gpu=gpu_enabled,
        placement_strategy=placement_strategy,
    )

    factory = original.as_placement_group_factory()
    rebuilt = ScalingConfig.from_placement_group_factory(factory)

    # Config -> factory -> config is lossless ...
    assert original == rebuilt
    # ... and so is factory -> config -> factory.
    assert rebuilt.as_placement_group_factory() == factory